"examples/vscode:/vscode.git/clone" did not exist on "7aa494b3bc8e1ce76412057de148702e468fa1a5"
Commit 3cea66d1 authored by Vighnesh Birodkar, committed by TF Object Detection Team

Ignore loss computation in overlapping boxes for DeepMAC.

PiperOrigin-RevId: 452829947
parent a62ef994
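Roughly, the new option gates the per-instance mask loss with a 0/1 validity map so that pixels covered by two or more boxes of the same class stop contributing gradients. Below is a minimal sketch of that gating with toy shapes; a plain sigmoid cross-entropy stands in for whatever classification loss is configured, and `valid_map` is a hypothetical placeholder for the map the commit actually computes.

import tensorflow as tf

batch, num_instances, height, width = 1, 2, 4, 4
mask_logits = tf.random.normal((batch, num_instances, height, width))
mask_gt = tf.cast(
    tf.random.uniform((batch, num_instances, height, width)) > 0.5, tf.float32)

# In the commit this map comes from per_instance_no_class_overlap(); here it
# is just all-ones for illustration. It is 0 wherever two or more boxes of
# the same class cover a pixel, and 1 everywhere else.
valid_map = tf.ones((batch, num_instances, height, width))

# Multiplying the logits by the 0/1 map zeroes the gradient of the loss with
# respect to the predictions at invalid pixels, so overlap regions no longer
# influence training for that instance.
gated_logits = mask_logits * valid_map
per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=mask_gt, logits=gated_logits)
loss_per_instance = tf.reduce_mean(per_pixel_loss, axis=[2, 3])  # [batch, num_instances]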
@@ -68,7 +68,8 @@ DeepMACParams = collections.namedtuple('DeepMACParams', [
'augmented_self_supervision_loss',
'augmented_self_supervision_scale_min',
'augmented_self_supervision_scale_max',
'pointly_supervised_keypoint_loss_weight'
'pointly_supervised_keypoint_loss_weight',
'ignore_per_class_box_overlap'
])
@@ -254,6 +255,36 @@ def filter_masked_classes(masked_class_ids, classes, weights, masks):
)
def per_instance_no_class_overlap(classes, boxes, height, width):
"""Returns 1s inside boxes but overlapping boxes of same class are zeroed out.
Args:
classes: A [batch_size, num_instances, num_classes] float tensor containing
the one-hot encoded classes.
boxes: A [batch_size, num_instances, 4] shaped float tensor of normalized
boxes.
height: int, height of the desired mask.
width: int, width of the desired mask.
Returns:
mask: A [batch_size, num_instances, height, width] float tensor of 0s and
1s.
"""
box_mask = fill_boxes(boxes, height, width)
per_class_box_mask = (
box_mask[:, :, tf.newaxis, :, :] *
classes[:, :, :, tf.newaxis, tf.newaxis])
per_class_instance_count = tf.reduce_sum(per_class_box_mask, axis=1)
per_class_valid_map = per_class_instance_count < 2
class_indices = tf.argmax(classes, axis=2)
per_instance_valid_map = tf.gather(
per_class_valid_map, class_indices, batch_dims=1)
return tf.cast(per_instance_valid_map, tf.float32)
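As a quick illustration (mirroring the unit test added further down, and assuming the usual object_detection.meta_architectures import path), two overlapping boxes of the same class have their shared pixels zeroed out, while boxes of different classes are left untouched:

import tensorflow as tf
from object_detection.meta_architectures import deepmac_meta_arch  # assumed path

# Batch of 1: a full-image box and a smaller box of the same class nested in
# its top-left corner.
boxes = tf.constant([[[0.0, 0.0, 1.0, 1.0],
                      [0.0, 0.0, 0.4, 0.4]]], dtype=tf.float32)
classes = tf.constant([[[0.0, 1.0, 0.0],
                        [0.0, 1.0, 0.0]]], dtype=tf.float32)

valid = deepmac_meta_arch.per_instance_no_class_overlap(classes, boxes, 2, 2)
# valid has shape [1, 2, 2, 2]; the top-left pixel is covered by both
# same-class boxes, so it is 0 for each instance and 1 everywhere else.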
def flatten_first2_dims(tensor):
"""Flatten first 2 dimensions of a tensor.
@@ -1144,13 +1175,15 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
def predict(self, preprocessed_inputs, true_image_shapes):
prediction_dict = super(DeepMACMetaArch, self).predict(
preprocessed_inputs, true_image_shapes)
mask_logits = self._predict_mask_logits_from_gt_boxes(prediction_dict)
prediction_dict[MASK_LOGITS_GT_BOXES] = mask_logits
if self._deepmac_params.augmented_self_supervision_loss_weight > 0.0:
prediction_dict[SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS] = (
self._predict_deaugmented_mask_logits_on_augmented_inputs(
preprocessed_inputs, true_image_shapes))
if self.groundtruth_has_field(fields.BoxListFields.boxes):
mask_logits = self._predict_mask_logits_from_gt_boxes(prediction_dict)
prediction_dict[MASK_LOGITS_GT_BOXES] = mask_logits
if self._deepmac_params.augmented_self_supervision_loss_weight > 0.0:
prediction_dict[SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS] = (
self._predict_deaugmented_mask_logits_on_augmented_inputs(
preprocessed_inputs, true_image_shapes))
return prediction_dict
def _predict_deaugmented_mask_logits_on_augmented_inputs(
@@ -1349,14 +1382,17 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
raise ValueError('Unknown loss aggregation - {}'.format(method))
def _compute_mask_prediction_loss(
self, boxes, mask_logits, mask_gt):
self, boxes, mask_logits, mask_gt, classes):
"""Compute the per-instance mask loss.
Args:
boxes: A [batch_size, num_instances, 4] float tensor of GT boxes.
mask_logits: A [batch_suze, num_instances, height, width] float tensor of
boxes: A [batch_size, num_instances, 4] float tensor of GT boxes in
normalized coordinates.
mask_logits: A [batch_size, num_instances, height, width] float tensor of
predicted masks.
mask_gt: The groundtruth mask of the same shape as mask_logits.
classes: A [batch_size, num_instances, num_classes] shaped tensor of
one-hot encoded classes.
Returns:
loss: A [batch_size, num_instances] shaped tensor with the loss for each
@@ -1369,9 +1405,19 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
batch_size, num_instances = tf.shape(boxes)[0], tf.shape(boxes)[1]
mask_logits = self._resize_logits_like_gt(mask_logits, mask_gt)
height, width = tf.shape(mask_logits)[2], tf.shape(mask_logits)[3]

if self._deepmac_params.ignore_per_class_box_overlap:
  mask_logits *= per_instance_no_class_overlap(
      classes, boxes, height, width)
mask_logits = tf.reshape(mask_logits, [batch_size * num_instances, -1, 1])
mask_gt = tf.reshape(mask_gt, [batch_size * num_instances, -1, 1])
loss = self._deepmac_params.classification_loss(
prediction_tensor=mask_logits,
target_tensor=mask_gt,
@@ -1660,7 +1706,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return tf.reshape(loss, [batch_size, num_instances])
def _compute_deepmac_losses(
self, boxes, masks_logits, masks_gt, image,
self, boxes, masks_logits, masks_gt, classes, image,
self_supervised_masks_logits=None, keypoints_gt=None,
keypoints_depth_gt=None):
"""Returns the mask loss per instance.
@@ -1674,6 +1720,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
masks_gt: A [batch_size, num_instances, output_height, output_width] float
tensor containing the groundtruth masks. If masks_gt is None,
DEEP_MASK_ESTIMATION is filled with 0s.
classes: A [batch_size, num_instances, num_classes] tensor of one-hot
encoded classes.
image: [batch_size, output_height, output_width, channels] float tensor
denoting the input image.
self_supervised_masks_logits: Optional self-supervised mask logits to
@@ -1712,7 +1760,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
masks_gt = self._get_groundtruth_mask_output(
boxes_for_crop, masks_gt)
mask_prediction_loss = self._compute_mask_prediction_loss(
boxes_for_crop, masks_logits, masks_gt)
boxes_for_crop, masks_logits, masks_gt, classes)
box_consistency_loss = self._compute_box_consistency_loss(
boxes, boxes_for_crop, masks_logits)
@@ -1803,7 +1851,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
gt_weights, gt_masks)
sample_loss_dict = self._compute_deepmac_losses(
gt_boxes, mask_logits, gt_masks, image,
boxes=gt_boxes, masks_logits=mask_logits, masks_gt=gt_masks,
classes=gt_classes, image=image,
self_supervised_masks_logits=self_supervised_mask_logits,
keypoints_gt=gt_keypoints, keypoints_depth_gt=gt_depths)
@@ -109,7 +109,8 @@ def build_meta_arch(**override_params):
augmented_self_supervision_loss='loss_dice',
augmented_self_supervision_scale_min=1.0,
augmented_self_supervision_scale_max=1.0,
pointly_supervised_keypoint_loss_weight=1.0)
pointly_supervised_keypoint_loss_weight=1.0,
ignore_per_class_box_overlap=False)
params.update(override_params)
@@ -199,6 +200,7 @@ DEEPMAC_PROTO_TEXT = """
augmented_self_supervision_scale_min: 0.42
augmented_self_supervision_scale_max: 1.42
pointly_supervised_keypoint_loss_weight: 0.13
ignore_per_class_box_overlap: true
"""
@@ -229,6 +231,7 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
params.augmented_self_supervision_scale_max, 1.42)
self.assertAlmostEqual(
params.pointly_supervised_keypoint_loss_weight, 0.13)
self.assertTrue(params.ignore_per_class_box_overlap)
def test_subsample_trivial(self):
"""Test subsampling masks."""
@@ -531,6 +534,18 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
expected_output = np.reshape(expected_output, (1, 1, 1, 1))
self.assertAllClose(expected_output, out)
def test_per_instance_no_class_overlap(self):
boxes = tf.constant([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.4, 0.4]],
[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]],
dtype=tf.float32)
classes = tf.constant([[[0, 1, 0], [0, 1, 0]], [[0, 1, 0], [1, 0, 0]]],
dtype=tf.float32)
output = deepmac_meta_arch.per_instance_no_class_overlap(
classes, boxes, 2, 2)
self.assertEqual(output.shape, (2, 2, 2, 2))
self.assertAllClose(output[1], np.ones((2, 2, 2)))
self.assertAllClose(output[0, 1], [[0., 1.0], [1.0, 1.0]])
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMaskHeadTest(tf.test.TestCase, parameterized.TestCase):
@@ -943,6 +958,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
def test_predict_self_supervised_deaugmented_mask_logits(self):
tf.keras.backend.set_learning_phase(True)
model = build_meta_arch(
augmented_self_supervision_loss_weight=1.0,
predict_full_resolution_masks=True)
@@ -967,9 +983,10 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
masks[0, 0, :16, :16] = 1.0
masks[0, 1, 16:, 16:] = 1.0
masks_pred = tf.fill((1, 2, 32, 32), 0.9)
classes = tf.zeros((1, 2, 5))
loss_dict = model._compute_deepmac_losses(
boxes, masks_pred, masks, tf.zeros((1, 16, 16, 3)))
boxes, masks_pred, masks, classes, tf.zeros((1, 16, 16, 3)))
self.assertAllClose(
loss_dict[deepmac_meta_arch.DEEP_MASK_ESTIMATION],
np.zeros((1, 2)) - tf.math.log(tf.nn.sigmoid(0.9)))
@@ -980,9 +997,10 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
boxes = tf.constant([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]])
masks = tf.ones((1, 2, 128, 128), dtype=tf.float32)
masks_pred = tf.fill((1, 2, 32, 32), 0.9)
classes = tf.zeros((1, 2, 5))
loss_dict = model._compute_deepmac_losses(
boxes, masks_pred, masks, tf.zeros((1, 32, 32, 3)))
boxes, masks_pred, masks, classes, tf.zeros((1, 32, 32, 3)))
self.assertAllClose(
loss_dict[deepmac_meta_arch.DEEP_MASK_ESTIMATION],
np.zeros((1, 2)) - tf.math.log(tf.nn.sigmoid(0.9)))
@@ -995,9 +1013,10 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
masks = np.ones((1, 2, 128, 128), dtype=np.float32)
masks = tf.constant(masks)
masks_pred = tf.fill((1, 2, 32, 32), 0.9)
classes = tf.zeros((1, 2, 5))
loss_dict = model._compute_deepmac_losses(
boxes, masks_pred, masks, tf.zeros((1, 32, 32, 3)))
boxes, masks_pred, masks, classes, tf.zeros((1, 32, 32, 3)))
pred = tf.nn.sigmoid(0.9)
expected = (1.0 - ((2.0 * pred) / (1.0 + pred)))
self.assertAllClose(loss_dict[deepmac_meta_arch.DEEP_MASK_ESTIMATION],
@@ -1007,9 +1026,10 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
boxes = tf.zeros([1, 0, 4])
masks = tf.zeros([1, 0, 128, 128])
classes = tf.zeros((1, 2, 5))
loss_dict = self.model._compute_deepmac_losses(
boxes, masks, masks,
boxes, masks, masks, classes,
tf.zeros((1, 16, 16, 3)))
self.assertEqual(loss_dict[deepmac_meta_arch.DEEP_MASK_ESTIMATION].shape,
(1, 0))
@@ -1476,6 +1496,33 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
self.assertGreater(loss['Loss/' + weak_loss], 0.0,
'{} was <= 0'.format(weak_loss))
def test_eval_loss_and_postprocess_keys(self):
model = build_meta_arch(
use_dice_loss=True,
augmented_self_supervision_loss_weight=1.0,
augmented_self_supervision_max_translation=0.5,
predict_full_resolution_masks=True)
true_image_shapes = tf.constant([[32, 32, 3]], dtype=tf.int32)
prediction_dict = model.predict(
tf.zeros((1, 32, 32, 3)), true_image_shapes)
output = model.postprocess(prediction_dict, true_image_shapes)
self.assertEqual(output['detection_boxes'].shape, (1, 5, 4))
self.assertEqual(output['detection_masks'].shape, (1, 5, 128, 128))
model.provide_groundtruth(
groundtruth_boxes_list=[
tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)] * 1,
groundtruth_classes_list=[tf.one_hot([1, 0, 1, 1, 1], depth=6)] * 1,
groundtruth_weights_list=[tf.ones(5)] * 1,
groundtruth_masks_list=[tf.ones((5, 32, 32))] * 1,
groundtruth_keypoints_list=[tf.zeros((5, 10, 2))] * 1,
groundtruth_keypoint_depths_list=[tf.zeros((5, 10))] * 1)
prediction_dict = model.predict(
tf.zeros((1, 32, 32, 3)), true_image_shapes)
model.loss(prediction_dict, true_image_shapes)
def test_loss_weight_response(self):
tf.random.set_seed(12)
model = build_meta_arch(
@@ -1661,6 +1708,30 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual(loss.shape, (1, 1))
self.assertAllClose(expected_loss, loss)
def test_ignore_per_class_box_overlap(self):
tf.keras.backend.set_learning_phase(True)
model = build_meta_arch(
use_dice_loss=False,
predict_full_resolution_masks=True,
network_type='cond_inst1',
dim=9,
pixel_embedding_dim=8,
use_instance_embedding=False,
use_xy=False,
pointly_supervised_keypoint_loss_weight=1.0,
ignore_per_class_box_overlap=True)
self.assertTrue(model._deepmac_params.ignore_per_class_box_overlap)
mask_logits = tf.zeros((2, 3, 16, 16))
mask_gt = tf.zeros((2, 3, 32, 32))
boxes = tf.zeros((2, 3, 4))
classes = tf.zeros((2, 3, 5))
loss = model._compute_mask_prediction_loss(
boxes, mask_logits, mask_gt, classes)
self.assertEqual(loss.shape, (2, 3))
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class FullyConnectedMaskHeadTest(tf.test.TestCase):
@@ -403,7 +403,7 @@ message CenterNet {
// Mask prediction support using DeepMAC. See https://arxiv.org/abs/2104.00613
// Next ID 34
// Next ID 35
message DeepMACMaskEstimation {
// The loss used for penalizing mask predictions.
optional ClassificationLoss classification_loss = 1;
@@ -531,6 +531,10 @@ message CenterNet {
// Depth = -1 is assumed to be background.
optional float pointly_supervised_keypoint_loss_weight = 33 [default = 0.0];
// When set, loss computation is ignored at pixels that fall within
// 2 boxes of the same class.
optional bool ignore_per_class_box_overlap = 34 [default = false];
}
optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
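For completeness, a hedged sketch of how the new flag could be set programmatically on the proto; the generated module path object_detection.protos.center_net_pb2 is assumed here rather than taken from the commit.

from google.protobuf import text_format
from object_detection.protos import center_net_pb2  # assumed generated module

# Parse the new field into the nested DeepMACMaskEstimation message defined
# in center_net.proto and confirm it round-trips as a boolean.
deepmac_config = text_format.Parse(
    'ignore_per_class_box_overlap: true',
    center_net_pb2.CenterNet.DeepMACMaskEstimation())
assert deepmac_config.ignore_per_class_box_overlap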