"...git@developer.sourcefind.cn:change/sglang.git" did not exist on "0ac61146947ad5bb202ce08a81431eb0daf43aef"
Commit 3324d816 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Implement embedding based similarity mask head for DeepMAC.

PiperOrigin-RevId: 394512536
parent 0daae829
...@@ -263,7 +263,8 @@ def _build_classification_loss(loss_config): ...@@ -263,7 +263,8 @@ def _build_classification_loss(loss_config):
elif loss_type == 'weighted_dice_classification_loss': elif loss_type == 'weighted_dice_classification_loss':
config = loss_config.weighted_dice_classification_loss config = loss_config.weighted_dice_classification_loss
return losses.WeightedDiceClassificationLoss( return losses.WeightedDiceClassificationLoss(
squared_normalization=config.squared_normalization) squared_normalization=config.squared_normalization,
is_prediction_probability=config.is_prediction_probability)
else: else:
raise ValueError('Empty loss config.') raise ValueError('Empty loss config.')
...@@ -286,15 +286,19 @@ class WeightedDiceClassificationLoss(Loss): ...@@ -286,15 +286,19 @@ class WeightedDiceClassificationLoss(Loss):
""" """
def __init__(self, squared_normalization): def __init__(self, squared_normalization, is_prediction_probability=False):
"""Initializes the loss object. """Initializes the loss object.
Args: Args:
squared_normalization: boolean, if set, we square the probabilities in the squared_normalization: boolean, if set, we square the probabilities in the
denominator term used for normalization. denominator term used for normalization.
is_prediction_probability: boolean, whether or not the input
prediction_tensor represents a probability. If false, it is
first converted to a probability by applying sigmoid.
""" """
self._squared_normalization = squared_normalization self._squared_normalization = squared_normalization
self.is_prediction_probability = is_prediction_probability
super(WeightedDiceClassificationLoss, self).__init__() super(WeightedDiceClassificationLoss, self).__init__()
def _compute_loss(self, def _compute_loss(self,
...@@ -332,6 +336,9 @@ class WeightedDiceClassificationLoss(Loss): ...@@ -332,6 +336,9 @@ class WeightedDiceClassificationLoss(Loss):
tf.shape(prediction_tensor)[2]), tf.shape(prediction_tensor)[2]),
[1, 1, -1]) [1, 1, -1])
if self.is_prediction_probability:
prob_tensor = prediction_tensor
else:
prob_tensor = tf.nn.sigmoid(prediction_tensor) prob_tensor = tf.nn.sigmoid(prediction_tensor)
if self._squared_normalization: if self._squared_normalization:
......
...@@ -36,7 +36,8 @@ class DeepMACParams( ...@@ -36,7 +36,8 @@ class DeepMACParams(
'allowed_masked_classes_ids', 'mask_size', 'mask_num_subsamples', 'allowed_masked_classes_ids', 'mask_size', 'mask_num_subsamples',
'use_xy', 'network_type', 'use_instance_embedding', 'num_init_channels', 'use_xy', 'network_type', 'use_instance_embedding', 'num_init_channels',
'predict_full_resolution_masks', 'postprocess_crop_size', 'predict_full_resolution_masks', 'postprocess_crop_size',
'max_roi_jitter_ratio', 'roi_jitter_mode', 'box_consistency_loss_weight' 'max_roi_jitter_ratio', 'roi_jitter_mode',
'box_consistency_loss_weight',
])): ])):
"""Class holding the DeepMAC network configutration.""" """Class holding the DeepMAC network configutration."""
...@@ -125,6 +126,9 @@ def _get_deepmac_network_by_type(name, num_init_channels, mask_size=None): ...@@ -125,6 +126,9 @@ def _get_deepmac_network_by_type(name, num_init_channels, mask_size=None):
raise ValueError('Mask size must be set.') raise ValueError('Mask size must be set.')
return FullyConnectedMaskHead(num_init_channels, mask_size) return FullyConnectedMaskHead(num_init_channels, mask_size)
elif name == 'embedding_distance_probability':
return tf.keras.layers.Lambda(lambda x: x)
elif name.startswith('resnet'): elif name.startswith('resnet'):
return ResNetMaskNetwork(name, num_init_channels) return ResNetMaskNetwork(name, num_init_channels)
...@@ -262,6 +266,25 @@ def fill_boxes(boxes, height, width): ...@@ -262,6 +266,25 @@ def fill_boxes(boxes, height, width):
return tf.cast(filled_boxes, tf.float32) return tf.cast(filled_boxes, tf.float32)
def embedding_distance_to_probability(x, y):
"""Compute probability based on pixel-wise embedding distance.
Args:
x: [num_instances, height, width, dimension] float tensor input.
y: [num_instances, height, width, dimension] or
[num_instances, 1, 1, dimension] float tensor input. When the height
and width dimensions are 1, TF will broadcast it.
Returns:
dist: [num_instances, height, width, 1] A float tensor returning
the per-pixel probability. Pixels whose embeddings are close in
euclidean distance get a probability of close to 1.
"""
diff = x - y
squared_dist = tf.reduce_sum(diff * diff, axis=3, keepdims=True)
return tf.exp(-squared_dist)
class ResNetMaskNetwork(tf.keras.layers.Layer): class ResNetMaskNetwork(tf.keras.layers.Layer):
"""A small wrapper around ResNet blocks to predict masks.""" """A small wrapper around ResNet blocks to predict masks."""
...@@ -366,6 +389,16 @@ class MaskHeadNetwork(tf.keras.layers.Layer): ...@@ -366,6 +389,16 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
network_type, num_init_channels, mask_size) network_type, num_init_channels, mask_size)
self._use_instance_embedding = use_instance_embedding self._use_instance_embedding = use_instance_embedding
self._network_type = network_type
if (self._use_instance_embedding and
(self._network_type == 'embedding_distance_probability')):
raise ValueError(('Cannot feed instance embedding to mask head when '
'computing distance from instance embedding.'))
if network_type == 'embedding_distance_probability':
self.project_out = tf.keras.layers.Lambda(lambda x: x)
else:
self.project_out = tf.keras.layers.Conv2D( self.project_out = tf.keras.layers.Conv2D(
filters=1, kernel_size=1, activation=None) filters=1, kernel_size=1, activation=None)
...@@ -388,10 +421,9 @@ class MaskHeadNetwork(tf.keras.layers.Layer): ...@@ -388,10 +421,9 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
height = tf.shape(pixel_embedding)[1] height = tf.shape(pixel_embedding)[1]
width = tf.shape(pixel_embedding)[2] width = tf.shape(pixel_embedding)[2]
if self._use_instance_embedding:
instance_embedding = instance_embedding[:, tf.newaxis, tf.newaxis, :] instance_embedding = instance_embedding[:, tf.newaxis, tf.newaxis, :]
instance_embedding = tf.tile(instance_embedding, [1, height, width, 1]) instance_embedding = tf.tile(instance_embedding, [1, height, width, 1])
if self._use_instance_embedding:
inputs = tf.concat([pixel_embedding, instance_embedding], axis=3) inputs = tf.concat([pixel_embedding, instance_embedding], axis=3)
else: else:
inputs = pixel_embedding inputs = pixel_embedding
...@@ -400,6 +432,10 @@ class MaskHeadNetwork(tf.keras.layers.Layer): ...@@ -400,6 +432,10 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
if isinstance(out, list): if isinstance(out, list):
out = out[-1] out = out[-1]
if self._network_type == 'embedding_distance_probability':
instance_embedding = instance_embedding[:, tf.newaxis, tf.newaxis, :]
out = embedding_distance_to_probability(instance_embedding, out)
if out.shape[-1] > 1: if out.shape[-1] > 1:
out = self.project_out(out) out = self.project_out(out)
...@@ -466,6 +502,25 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -466,6 +502,25 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
if self._deepmac_params.mask_num_subsamples > 0: if self._deepmac_params.mask_num_subsamples > 0:
raise ValueError('Subsampling masks is currently not supported.') raise ValueError('Subsampling masks is currently not supported.')
if self._deepmac_params.network_type == 'embedding_distance_probability':
if self._deepmac_params.use_xy:
raise ValueError(
'Cannot use x/y coordinates when using embedding distance.')
pixel_embedding_dim = self._deepmac_params.pixel_embedding_dim
dim = self._deepmac_params.dim
if dim != pixel_embedding_dim:
raise ValueError(
'When using embedding distance mask head, '
f'pixel_embedding_dim({pixel_embedding_dim}) '
f'must be same as dim({dim}).')
loss = self._deepmac_params.classification_loss
if ((not isinstance(loss, losses.WeightedDiceClassificationLoss))
or (not loss.is_prediction_probability)):
raise ValueError('Only dice loss with is_prediction_probability=true '
'is supported with embedding distance mask head.')
super(DeepMACMetaArch, self).__init__( super(DeepMACMetaArch, self).__init__(
is_training=is_training, add_summaries=add_summaries, is_training=is_training, add_summaries=add_summaries,
num_classes=num_classes, feature_extractor=feature_extractor, num_classes=num_classes, feature_extractor=feature_extractor,
...@@ -909,6 +964,9 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -909,6 +964,9 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
mask_logits = crop_masks_within_boxes( mask_logits = crop_masks_within_boxes(
mask_logits, boxes, self._deepmac_params.postprocess_crop_size) mask_logits, boxes, self._deepmac_params.postprocess_crop_size)
if self._deepmac_params.network_type == 'embedding_distance_probability':
masks_prob = mask_logits
else:
masks_prob = tf.nn.sigmoid(mask_logits) masks_prob = tf.nn.sigmoid(mask_logits)
return masks_prob return masks_prob
......
...@@ -61,7 +61,10 @@ class MockMaskNet(tf.keras.layers.Layer): ...@@ -61,7 +61,10 @@ class MockMaskNet(tf.keras.layers.Layer):
def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False, def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
mask_num_subsamples=-1): use_instance_embedding=True, mask_num_subsamples=-1,
network_type='hourglass10', use_xy=True,
pixel_embedding_dim=2,
dice_loss_prediction_probability=False):
"""Builds the DeepMAC meta architecture.""" """Builds the DeepMAC meta architecture."""
feature_extractor = DummyFeatureExtractor( feature_extractor = DummyFeatureExtractor(
...@@ -84,7 +87,9 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False, ...@@ -84,7 +87,9 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
use_labeled_classes=False) use_labeled_classes=False)
if use_dice_loss: if use_dice_loss:
classification_loss = losses.WeightedDiceClassificationLoss(False) classification_loss = losses.WeightedDiceClassificationLoss(
squared_normalization=False,
is_prediction_probability=dice_loss_prediction_probability)
else: else:
classification_loss = losses.WeightedSigmoidClassificationLoss() classification_loss = losses.WeightedSigmoidClassificationLoss()
...@@ -92,13 +97,13 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False, ...@@ -92,13 +97,13 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
classification_loss=classification_loss, classification_loss=classification_loss,
dim=8, dim=8,
task_loss_weight=1.0, task_loss_weight=1.0,
pixel_embedding_dim=2, pixel_embedding_dim=pixel_embedding_dim,
allowed_masked_classes_ids=[], allowed_masked_classes_ids=[],
mask_size=16, mask_size=16,
mask_num_subsamples=mask_num_subsamples, mask_num_subsamples=mask_num_subsamples,
use_xy=True, use_xy=use_xy,
network_type='hourglass10', network_type=network_type,
use_instance_embedding=True, use_instance_embedding=use_instance_embedding,
num_init_channels=8, num_init_channels=8,
predict_full_resolution_masks=predict_full_resolution_masks, predict_full_resolution_masks=predict_full_resolution_masks,
postprocess_crop_size=128, postprocess_crop_size=128,
...@@ -125,7 +130,7 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False, ...@@ -125,7 +130,7 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.') @unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACUtilsTest(tf.test.TestCase): class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
def test_subsample_trivial(self): def test_subsample_trivial(self):
"""Test subsampling masks.""" """Test subsampling masks."""
...@@ -169,12 +174,22 @@ class DeepMACUtilsTest(tf.test.TestCase): ...@@ -169,12 +174,22 @@ class DeepMACUtilsTest(tf.test.TestCase):
features, boxes, 32) features, boxes, 32)
self.assertEqual(output.shape, (5, 32, 32, 7)) self.assertEqual(output.shape, (5, 32, 32, 7))
def test_embedding_distance_prob_shape(self):
dist = deepmac_meta_arch.embedding_distance_to_probability(
tf.ones((4, 32, 32, 8)), tf.zeros((4, 32, 32, 8)))
self.assertEqual(dist.shape, (4, 32, 32, 1))
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.') @parameterized.parameters([1e-20, 1e20])
class DeepMACMetaArchTest(tf.test.TestCase): def test_embedding_distance_prob_value(self, value):
dist = deepmac_meta_arch.embedding_distance_to_probability(
tf.zeros((1, 1, 1, 8)), value + tf.zeros((1, 1, 1, 8))).numpy()
max_float = np.finfo(dist.dtype).max
self.assertLess(dist.max(), max_float)
self.assertGreater(dist.max(), -max_float)
def setUp(self): # pylint:disable=g-missing-super-call
self.model = build_meta_arch() @unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMaskHeadTest(tf.test.TestCase):
def test_mask_network(self): def test_mask_network(self):
net = deepmac_meta_arch.MaskHeadNetwork('hourglass10', 8) net = deepmac_meta_arch.MaskHeadNetwork('hourglass10', 8)
...@@ -203,6 +218,38 @@ class DeepMACMetaArchTest(tf.test.TestCase): ...@@ -203,6 +218,38 @@ class DeepMACMetaArchTest(tf.test.TestCase):
out = call_func(tf.zeros((2, 4)), tf.zeros((2, 32, 32, 16)), training=True) out = call_func(tf.zeros((2, 4)), tf.zeros((2, 32, 32, 16)), training=True)
self.assertEqual(out.shape, (2, 32, 32)) self.assertEqual(out.shape, (2, 32, 32))
def test_mask_network_embedding_distance_zero_dist(self):
net = deepmac_meta_arch.MaskHeadNetwork(
'embedding_distance_probability', num_init_channels=8,
use_instance_embedding=False)
call_func = tf.function(net.__call__)
out = call_func(tf.zeros((2, 7)), tf.zeros((2, 32, 32, 7)), training=True)
self.assertEqual(out.shape, (2, 32, 32))
self.assertAllGreater(out.numpy(), -np.inf)
self.assertAllLess(out.numpy(), np.inf)
def test_mask_network_embedding_distance_small_dist(self):
net = deepmac_meta_arch.MaskHeadNetwork(
'embedding_distance_probability', num_init_channels=-1,
use_instance_embedding=False)
call_func = tf.function(net.__call__)
out = call_func(1e6 + tf.zeros((2, 7)),
tf.zeros((2, 32, 32, 7)), training=True)
self.assertEqual(out.shape, (2, 32, 32))
self.assertAllGreater(out.numpy(), -np.inf)
self.assertAllLess(out.numpy(), np.inf)
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self): # pylint:disable=g-missing-super-call
self.model = build_meta_arch()
def test_get_mask_head_input(self): def test_get_mask_head_input(self):
boxes = tf.constant([[0., 0., 0.25, 0.25], [0.75, 0.75, 1.0, 1.0]], boxes = tf.constant([[0., 0., 0.25, 0.25], [0.75, 0.75, 1.0, 1.0]],
...@@ -349,6 +396,37 @@ class DeepMACMetaArchTest(tf.test.TestCase): ...@@ -349,6 +396,37 @@ class DeepMACMetaArchTest(tf.test.TestCase):
prob = tf.nn.sigmoid(0.9).numpy() prob = tf.nn.sigmoid(0.9).numpy()
self.assertAllClose(masks, prob * np.ones((2, 3, 16, 16))) self.assertAllClose(masks, prob * np.ones((2, 3, 16, 16)))
def test_postprocess_emb_dist(self):
model = build_meta_arch(network_type='embedding_distance_probability',
use_instance_embedding=False,
use_xy=False, pixel_embedding_dim=8,
use_dice_loss=True,
dice_loss_prediction_probability=True)
boxes = np.zeros((2, 3, 4), dtype=np.float32)
boxes[:, :, [0, 2]] = 0.0
boxes[:, :, [1, 3]] = 8.0
boxes = tf.constant(boxes)
masks = model._postprocess_masks(
boxes, tf.zeros((2, 32, 32, 2)), tf.zeros((2, 32, 32, 2)))
self.assertEqual(masks.shape, (2, 3, 16, 16))
def test_postprocess_emb_dist_fullres(self):
model = build_meta_arch(network_type='embedding_distance_probability',
predict_full_resolution_masks=True,
use_instance_embedding=False,
pixel_embedding_dim=8, use_xy=False,
use_dice_loss=True,
dice_loss_prediction_probability=True)
boxes = np.zeros((2, 3, 4), dtype=np.float32)
boxes = tf.constant(boxes)
masks = model._postprocess_masks(
boxes, tf.zeros((2, 32, 32, 2)), tf.zeros((2, 32, 32, 2)))
self.assertEqual(masks.shape, (2, 3, 128, 128))
def test_postprocess_no_crop_resize_shape(self): def test_postprocess_no_crop_resize_shape(self):
model = build_meta_arch(predict_full_resolution_masks=True) model = build_meta_arch(predict_full_resolution_masks=True)
...@@ -494,7 +572,7 @@ class FullyConnectedMaskHeadTest(tf.test.TestCase): ...@@ -494,7 +572,7 @@ class FullyConnectedMaskHeadTest(tf.test.TestCase):
class ResNetMaskHeadTest(tf.test.TestCase, parameterized.TestCase): class ResNetMaskHeadTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(['resnet4', 'resnet8', 'resnet20']) @parameterized.parameters(['resnet4', 'resnet8', 'resnet20'])
def test_pass(self, name): def test_forward(self, name):
net = deepmac_meta_arch.ResNetMaskNetwork(name, 8) net = deepmac_meta_arch.ResNetMaskNetwork(name, 8)
out = net(tf.zeros((3, 32, 32, 16))) out = net(tf.zeros((3, 32, 32, 16)))
self.assertEqual(out.shape[:3], (3, 32, 32)) self.assertEqual(out.shape[:3], (3, 32, 32))
......
...@@ -231,6 +231,10 @@ message WeightedDiceClassificationLoss { ...@@ -231,6 +231,10 @@ message WeightedDiceClassificationLoss {
// If set, we square the probabilities in the denominator term used for // If set, we square the probabilities in the denominator term used for
// normalization. // normalization.
optional bool squared_normalization = 1 [default=false]; optional bool squared_normalization = 1 [default=false];
// Whether or not the input prediction to the loss function is a
// probability. If not, the input is to be interpreted as logit
optional bool is_prediction_probability = 2 [default=false];
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment