Commit c626177d authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

DeepMAC support for point supervised dataset.

PiperOrigin-RevId: 445217499
parent 97b97572
...@@ -103,12 +103,12 @@ can also be applied to Mask R-CNN or without any detector at all. ...@@ -103,12 +103,12 @@ can also be applied to Mask R-CNN or without any detector at all.
Please see links below for more details Please see links below for more details
* [DeepMAC documentation](g3doc/deepmac.md). * [DeepMAC documentation](g3doc/deepmac.md).
* [Mask RCNN code](https://github.com/tensorflow/models/tree/master/official/projects/deepmac_maskrcnn) * [Mask RCNN code](https://github.com/tensorflow/models/tree/master/official/vision/beta/projects/deepmac_maskrcnn)
in TF Model garden code base. in TF Model garden code base.
* [DeepMAC Colab](./colab_tutorials/deepmac_colab.ipynb) that lets you run a * [DeepMAC Colab](./colab_tutorials/deepmac_colab.ipynb) that lets you run a
pre-trained DeepMAC model on user-specified boxes. Note that you are not pre-trained DeepMAC model on user-specified boxes. Note that you are not
restricted to COCO classes! restricted to COCO classes!
* Project website - [git.io/deepmac](https://google.github.io/deepmac/) * Project website - [git.io/deepmac](https://git.io/deepmac)
<b>Thanks to contributors</b>: Vighnesh Birodkar, Zhichao Lu, Siyang Li, <b>Thanks to contributors</b>: Vighnesh Birodkar, Zhichao Lu, Siyang Li,
Vivek Rathod, Jonathan Huang Vivek Rathod, Jonathan Huang
......
...@@ -22,6 +22,8 @@ the coordinates of the keypoint. ...@@ -22,6 +22,8 @@ the coordinates of the keypoint.
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from object_detection.utils import shape_utils
def scale(keypoints, y_scale, x_scale, scope=None): def scale(keypoints, y_scale, x_scale, scope=None):
"""Scales keypoint coordinates in x and y dimensions. """Scales keypoint coordinates in x and y dimensions.
...@@ -345,7 +347,8 @@ def keypoint_weights_from_visibilities(keypoint_visibilities, ...@@ -345,7 +347,8 @@ def keypoint_weights_from_visibilities(keypoint_visibilities,
""" """
keypoint_visibilities.get_shape().assert_has_rank(2) keypoint_visibilities.get_shape().assert_has_rank(2)
if per_keypoint_weights is None: if per_keypoint_weights is None:
num_keypoints = keypoint_visibilities.shape.as_list()[1] num_keypoints = shape_utils.combined_static_and_dynamic_shape(
keypoint_visibilities)[1]
per_keypoint_weight_mult = tf.ones((1, num_keypoints,), dtype=tf.float32) per_keypoint_weight_mult = tf.ones((1, num_keypoints,), dtype=tf.float32)
else: else:
per_keypoint_weight_mult = tf.expand_dims(per_keypoint_weights, axis=0) per_keypoint_weight_mult = tf.expand_dims(per_keypoint_weights, axis=0)
......
...@@ -81,7 +81,7 @@ Resolution | Mask head | Config name | Mask m ...@@ -81,7 +81,7 @@ Resolution | Mask head | Config name | Mask m
* [Mask RCNN code](https://github.com/tensorflow/models/tree/master/official/vision/beta/projects/deepmac_maskrcnn) * [Mask RCNN code](https://github.com/tensorflow/models/tree/master/official/vision/beta/projects/deepmac_maskrcnn)
in TF Model garden code base. in TF Model garden code base.
* Project website - [git.io/deepmac](https://google.github.io/deepmac/) * Project website - [git.io/deepmac](https://git.io/deepmac)
## Citation ## Citation
......
...@@ -34,6 +34,7 @@ MASK_LOGITS_GT_BOXES = 'MASK_LOGITS_GT_BOXES' ...@@ -34,6 +34,7 @@ MASK_LOGITS_GT_BOXES = 'MASK_LOGITS_GT_BOXES'
DEEP_MASK_ESTIMATION = 'deep_mask_estimation' DEEP_MASK_ESTIMATION = 'deep_mask_estimation'
DEEP_MASK_BOX_CONSISTENCY = 'deep_mask_box_consistency' DEEP_MASK_BOX_CONSISTENCY = 'deep_mask_box_consistency'
DEEP_MASK_COLOR_CONSISTENCY = 'deep_mask_color_consistency' DEEP_MASK_COLOR_CONSISTENCY = 'deep_mask_color_consistency'
DEEP_MASK_POINTLY_SUPERVISED = 'deep_mask_pointly_supervised'
SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS = ( SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS = (
'SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS') 'SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS')
DEEP_MASK_AUGMENTED_SELF_SUPERVISION = 'deep_mask_augmented_self_supervision' DEEP_MASK_AUGMENTED_SELF_SUPERVISION = 'deep_mask_augmented_self_supervision'
...@@ -41,8 +42,10 @@ LOSS_KEY_PREFIX = center_net_meta_arch.LOSS_KEY_PREFIX ...@@ -41,8 +42,10 @@ LOSS_KEY_PREFIX = center_net_meta_arch.LOSS_KEY_PREFIX
NEIGHBORS_2D = [[-1, -1], [-1, 0], [-1, 1], NEIGHBORS_2D = [[-1, -1], [-1, 0], [-1, 1],
[0, -1], [0, 1], [0, -1], [0, 1],
[1, -1], [1, 0], [1, 1]] [1, -1], [1, 0], [1, 1]]
WEAK_LOSSES = [DEEP_MASK_BOX_CONSISTENCY, DEEP_MASK_COLOR_CONSISTENCY, WEAK_LOSSES = [DEEP_MASK_BOX_CONSISTENCY, DEEP_MASK_COLOR_CONSISTENCY,
DEEP_MASK_AUGMENTED_SELF_SUPERVISION] DEEP_MASK_AUGMENTED_SELF_SUPERVISION,
DEEP_MASK_POINTLY_SUPERVISED]
MASK_LOSSES = WEAK_LOSSES + [DEEP_MASK_ESTIMATION] MASK_LOSSES = WEAK_LOSSES + [DEEP_MASK_ESTIMATION]
...@@ -64,7 +67,8 @@ DeepMACParams = collections.namedtuple('DeepMACParams', [ ...@@ -64,7 +67,8 @@ DeepMACParams = collections.namedtuple('DeepMACParams', [
'augmented_self_supervision_warmup_steps', 'augmented_self_supervision_warmup_steps',
'augmented_self_supervision_loss', 'augmented_self_supervision_loss',
'augmented_self_supervision_scale_min', 'augmented_self_supervision_scale_min',
'augmented_self_supervision_scale_max' 'augmented_self_supervision_scale_max',
'pointly_supervised_keypoint_loss_weight'
]) ])
...@@ -78,6 +82,8 @@ def _get_loss_weight(loss_name, config): ...@@ -78,6 +82,8 @@ def _get_loss_weight(loss_name, config):
return config.box_consistency_loss_weight return config.box_consistency_loss_weight
elif loss_name == DEEP_MASK_AUGMENTED_SELF_SUPERVISION: elif loss_name == DEEP_MASK_AUGMENTED_SELF_SUPERVISION:
return config.augmented_self_supervision_loss_weight return config.augmented_self_supervision_loss_weight
elif loss_name == DEEP_MASK_POINTLY_SUPERVISED:
return config.pointly_supervised_keypoint_loss_weight
else: else:
raise ValueError('Unknown loss - {}'.format(loss_name)) raise ValueError('Unknown loss - {}'.format(loss_name))
...@@ -1356,6 +1362,11 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1356,6 +1362,11 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
loss: A [batch_size, num_instances] shaped tensor with the loss for each loss: A [batch_size, num_instances] shaped tensor with the loss for each
instance. instance.
""" """
if mask_gt is None:
logging.info('No mask GT provided, mask loss is 0.')
return tf.zeros_like(boxes[:, :, 0])
batch_size, num_instances = tf.shape(boxes)[0], tf.shape(boxes)[1] batch_size, num_instances = tf.shape(boxes)[0], tf.shape(boxes)[1]
mask_logits = self._resize_logits_like_gt(mask_logits, mask_gt) mask_logits = self._resize_logits_like_gt(mask_logits, mask_gt)
...@@ -1572,9 +1583,86 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1572,9 +1583,86 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return loss return loss
def _compute_pointly_supervised_loss_from_keypoints(
self, mask_logits, keypoints_gt, keypoints_depth_gt):
"""Computes per-point mask loss from keypoints.
Args:
mask_logits: A [batch_size, num_instances, height, width] float tensor
denoting predicted masks.
keypoints_gt: A [batch_size, num_instances, num_keypoints, 2] float tensor
of normalize keypoint coordinates.
keypoints_depth_gt: A [batch_size, num_instances, num_keyponts] float
tensor of keypoint depths. We assume that +1 is foreground and -1
is background.
Returns:
loss: Pointly supervised loss with shape [batch_size, num_instances].
"""
if keypoints_gt is None:
logging.info(('Returning 0 pointly supervised loss because '
'keypoints are not given.'))
return tf.zeros(tf.shape(mask_logits)[:2])
if keypoints_depth_gt is None:
logging.info(('Returning 0 pointly supervised loss because '
'keypoint depths are not given.'))
return tf.zeros(tf.shape(mask_logits)[:2])
if not self._deepmac_params.predict_full_resolution_masks:
raise NotImplementedError(
'Pointly supervised loss not implemented with RoIAlign.')
num_keypoints = tf.shape(keypoints_gt)[2]
keypoints_nan = tf.math.is_nan(keypoints_gt)
keypoints_gt = tf.where(
keypoints_nan, tf.zeros_like(keypoints_gt), keypoints_gt)
weights = tf.cast(
tf.logical_not(tf.reduce_any(keypoints_nan, axis=3)), tf.float32)
height, width = tf.shape(mask_logits)[2], tf.shape(mask_logits)[3]
ky, kx = tf.unstack(keypoints_gt, axis=3)
height_f, width_f = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
ky = tf.clip_by_value(tf.cast(ky * height_f, tf.int32), 0, height - 1)
kx = tf.clip_by_value(tf.cast(kx * width_f, tf.int32), 0, width - 1)
keypoints_gt_int = tf.stack([ky, kx], axis=3)
mask_logits_flat, batch_size, num_instances = flatten_first2_dims(
mask_logits)
keypoints_gt_int_flat, _, _ = flatten_first2_dims(keypoints_gt_int)
keypoint_depths_flat, _, _ = flatten_first2_dims(keypoints_depth_gt)
weights_flat = tf.logical_not(
tf.reduce_any(keypoints_nan, axis=2))
weights_flat, _, _ = flatten_first2_dims(weights)
# TODO(vighneshb): Replace with bilinear interpolation
point_mask_logits = tf.gather_nd(
mask_logits_flat, keypoints_gt_int_flat, batch_dims=1)
point_mask_logits = tf.reshape(
point_mask_logits, [batch_size * num_instances, num_keypoints, 1])
labels = tf.cast(keypoint_depths_flat > 0.0, tf.float32)
labels = tf.reshape(
labels, [batch_size * num_instances, num_keypoints, 1])
weights_flat = tf.reshape(
weights_flat, [batch_size * num_instances, num_keypoints, 1])
loss = self._deepmac_params.classification_loss(
prediction_tensor=point_mask_logits, target_tensor=labels,
weights=weights_flat
)
loss = self._aggregate_classification_loss(
loss, gt=labels, pred=point_mask_logits, method='normalize_auto')
return tf.reshape(loss, [batch_size, num_instances])
def _compute_deepmac_losses( def _compute_deepmac_losses(
self, boxes, masks_logits, masks_gt, image, self, boxes, masks_logits, masks_gt, image,
self_supervised_masks_logits=None): self_supervised_masks_logits=None, keypoints_gt=None,
keypoints_depth_gt=None):
"""Returns the mask loss per instance. """Returns the mask loss per instance.
Args: Args:
...@@ -1584,19 +1672,28 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1584,19 +1672,28 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
float tensor containing the instance mask predictions in their logit float tensor containing the instance mask predictions in their logit
form. form.
masks_gt: A [batch_size, num_instances, output_height, output_width] float masks_gt: A [batch_size, num_instances, output_height, output_width] float
tensor containing the groundtruth masks. tensor containing the groundtruth masks. If masks_gt is None,
DEEP_MASK_ESTIMATION is filled with 0s.
image: [batch_size, output_height, output_width, channels] float tensor image: [batch_size, output_height, output_width, channels] float tensor
denoting the input image. denoting the input image.
self_supervised_masks_logits: Optional self-supervised mask logits to self_supervised_masks_logits: Optional self-supervised mask logits to
compare against of same shape as mask_logits. compare against of same shape as mask_logits.
keypoints_gt: A float tensor of shape
[batch_size, num_instances, num_keypoints, 2], representing the points
where we have mask supervision.
keypoints_depth_gt: A float tensor of shape
[batch_size, num_instances, num_keypoints] of keypoint depths which
indicate the mask label at the keypoint locations. depth=+1 is
foreground and depth=-1 is background.
Returns: Returns:
mask_prediction_loss: A [batch_size, num_instances] shaped float tensor tensor_dict: A dictionary with 4 keys, each mapping to a tensor of shape
containing the mask loss for each instance in the batch. [batch_size, num_instances]. The 4 keys are:
box_consistency_loss: A [batch_size, num_instances] shaped float tensor - DEEP_MASK_ESTIMATION
containing the box consistency loss for each instance in the batch. - DEEP_MASK_BOX_CONSISTENCY
box_consistency_loss: A [batch_size, num_instances] shaped float tensor - DEEP_MASK_COLOR_CONSISTENCY
containing the color consistency loss in the batch. - DEEP_MASK_AUGMENTED_SELF_SUPERVISION
- DEEP_MASK_POINTLY_SUPERVISED
""" """
if tf.keras.backend.learning_phase(): if tf.keras.backend.learning_phase():
...@@ -1611,11 +1708,11 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1611,11 +1708,11 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
else: else:
boxes_for_crop = boxes boxes_for_crop = boxes
mask_gt = self._get_groundtruth_mask_output( if masks_gt is not None:
boxes_for_crop, masks_gt) masks_gt = self._get_groundtruth_mask_output(
boxes_for_crop, masks_gt)
mask_prediction_loss = self._compute_mask_prediction_loss( mask_prediction_loss = self._compute_mask_prediction_loss(
boxes_for_crop, masks_logits, mask_gt) boxes_for_crop, masks_logits, masks_gt)
box_consistency_loss = self._compute_box_consistency_loss( box_consistency_loss = self._compute_box_consistency_loss(
boxes, boxes_for_crop, masks_logits) boxes, boxes_for_crop, masks_logits)
...@@ -1627,11 +1724,16 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1627,11 +1724,16 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
masks_logits, self_supervised_masks_logits, boxes, masks_logits, self_supervised_masks_logits, boxes,
) )
pointly_supervised_loss = (
self._compute_pointly_supervised_loss_from_keypoints(
masks_logits, keypoints_gt, keypoints_depth_gt))
return { return {
DEEP_MASK_ESTIMATION: mask_prediction_loss, DEEP_MASK_ESTIMATION: mask_prediction_loss,
DEEP_MASK_BOX_CONSISTENCY: box_consistency_loss, DEEP_MASK_BOX_CONSISTENCY: box_consistency_loss,
DEEP_MASK_COLOR_CONSISTENCY: color_consistency_loss, DEEP_MASK_COLOR_CONSISTENCY: color_consistency_loss,
DEEP_MASK_AUGMENTED_SELF_SUPERVISION: self_supervised_loss DEEP_MASK_AUGMENTED_SELF_SUPERVISION: self_supervised_loss,
DEEP_MASK_POINTLY_SUPERVISED: pointly_supervised_loss,
} }
def _get_lab_image(self, preprocessed_image): def _get_lab_image(self, preprocessed_image):
...@@ -1644,6 +1746,13 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1644,6 +1746,13 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
' consistency loss is not supported in TF1.')) ' consistency loss is not supported in TF1.'))
return tfio.experimental.color.rgb_to_lab(raw_image) return tfio.experimental.color.rgb_to_lab(raw_image)
def _maybe_get_gt_batch(self, field):
"""Returns a batch of groundtruth tensors if available, else None."""
if self.groundtruth_has_field(field):
return _batch_gt_list(self.groundtruth_lists(field))
else:
return None
def _compute_masks_loss(self, prediction_dict): def _compute_masks_loss(self, prediction_dict):
"""Computes the mask loss. """Computes the mask loss.
...@@ -1671,16 +1780,12 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1671,16 +1780,12 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
prediction_dict['preprocessed_inputs'], (height, width)) prediction_dict['preprocessed_inputs'], (height, width))
image = self._get_lab_image(preprocessed_image) image = self._get_lab_image(preprocessed_image)
# Iterate over multiple preidctions by backbone (for hourglass length=2) gt_boxes = self._maybe_get_gt_batch(fields.BoxListFields.boxes)
gt_weights = self._maybe_get_gt_batch(fields.BoxListFields.weights)
gt_boxes = _batch_gt_list( gt_classes = self._maybe_get_gt_batch(fields.BoxListFields.classes)
self.groundtruth_lists(fields.BoxListFields.boxes)) gt_masks = self._maybe_get_gt_batch(fields.BoxListFields.masks)
gt_weights = _batch_gt_list( gt_keypoints = self._maybe_get_gt_batch(fields.BoxListFields.keypoints)
self.groundtruth_lists(fields.BoxListFields.weights)) gt_depths = self._maybe_get_gt_batch(fields.BoxListFields.keypoint_depths)
gt_masks = _batch_gt_list(
self.groundtruth_lists(fields.BoxListFields.masks))
gt_classes = _batch_gt_list(
self.groundtruth_lists(fields.BoxListFields.classes))
mask_logits_list = prediction_dict[MASK_LOGITS_GT_BOXES] mask_logits_list = prediction_dict[MASK_LOGITS_GT_BOXES]
self_supervised_mask_logits_list = prediction_dict.get( self_supervised_mask_logits_list = prediction_dict.get(
...@@ -1688,6 +1793,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1688,6 +1793,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
[None] * len(mask_logits_list)) [None] * len(mask_logits_list))
assert len(mask_logits_list) == len(self_supervised_mask_logits_list) assert len(mask_logits_list) == len(self_supervised_mask_logits_list)
# Iterate over multiple preidctions by backbone (for hourglass length=2)
for (mask_logits, self_supervised_mask_logits) in zip( for (mask_logits, self_supervised_mask_logits) in zip(
mask_logits_list, self_supervised_mask_logits_list): mask_logits_list, self_supervised_mask_logits_list):
...@@ -1698,9 +1804,11 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1698,9 +1804,11 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
sample_loss_dict = self._compute_deepmac_losses( sample_loss_dict = self._compute_deepmac_losses(
gt_boxes, mask_logits, gt_masks, image, gt_boxes, mask_logits, gt_masks, image,
self_supervised_masks_logits=self_supervised_mask_logits) self_supervised_masks_logits=self_supervised_mask_logits,
keypoints_gt=gt_keypoints, keypoints_depth_gt=gt_depths)
sample_loss_dict[DEEP_MASK_ESTIMATION] *= valid_mask_weights sample_loss_dict[DEEP_MASK_ESTIMATION] *= valid_mask_weights
for loss_name in WEAK_LOSSES: for loss_name in WEAK_LOSSES:
sample_loss_dict[loss_name] *= gt_weights sample_loss_dict[loss_name] *= gt_weights
......
...@@ -108,7 +108,8 @@ def build_meta_arch(**override_params): ...@@ -108,7 +108,8 @@ def build_meta_arch(**override_params):
augmented_self_supervision_warmup_steps=0, augmented_self_supervision_warmup_steps=0,
augmented_self_supervision_loss='loss_dice', augmented_self_supervision_loss='loss_dice',
augmented_self_supervision_scale_min=1.0, augmented_self_supervision_scale_min=1.0,
augmented_self_supervision_scale_max=1.0) augmented_self_supervision_scale_max=1.0,
pointly_supervised_keypoint_loss_weight=1.0)
params.update(override_params) params.update(override_params)
...@@ -197,6 +198,7 @@ DEEPMAC_PROTO_TEXT = """ ...@@ -197,6 +198,7 @@ DEEPMAC_PROTO_TEXT = """
augmented_self_supervision_flip_probability: 0.9 augmented_self_supervision_flip_probability: 0.9
augmented_self_supervision_scale_min: 0.42 augmented_self_supervision_scale_min: 0.42
augmented_self_supervision_scale_max: 1.42 augmented_self_supervision_scale_max: 1.42
pointly_supervised_keypoint_loss_weight: 0.13
""" """
...@@ -225,6 +227,8 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase): ...@@ -225,6 +227,8 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
params.augmented_self_supervision_scale_min, 0.42) params.augmented_self_supervision_scale_min, 0.42)
self.assertAlmostEqual( self.assertAlmostEqual(
params.augmented_self_supervision_scale_max, 1.42) params.augmented_self_supervision_scale_max, 1.42)
self.assertAlmostEqual(
params.pointly_supervised_keypoint_loss_weight, 0.13)
def test_subsample_trivial(self): def test_subsample_trivial(self):
"""Test subsampling masks.""" """Test subsampling masks."""
...@@ -1440,9 +1444,12 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1440,9 +1444,12 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
loss_at_100[loss_key].numpy()) loss_at_100[loss_key].numpy())
def test_loss_keys(self): def test_loss_keys(self):
model = build_meta_arch(use_dice_loss=True, model = build_meta_arch(
augmented_self_supervision_loss_weight=1.0, use_dice_loss=True,
augmented_self_supervision_max_translation=0.5) augmented_self_supervision_loss_weight=1.0,
augmented_self_supervision_max_translation=0.5,
predict_full_resolution_masks=True)
prediction = { prediction = {
'preprocessed_inputs': tf.random.normal((3, 32, 32, 3)), 'preprocessed_inputs': tf.random.normal((3, 32, 32, 3)),
'MASK_LOGITS_GT_BOXES': [tf.random.normal((3, 5, 8, 8))] * 2, 'MASK_LOGITS_GT_BOXES': [tf.random.normal((3, 5, 8, 8))] * 2,
...@@ -1457,7 +1464,9 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1457,7 +1464,9 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)] * 3, tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)] * 3,
groundtruth_classes_list=[tf.one_hot([1, 0, 1, 1, 1], depth=6)] * 3, groundtruth_classes_list=[tf.one_hot([1, 0, 1, 1, 1], depth=6)] * 3,
groundtruth_weights_list=[tf.ones(5)] * 3, groundtruth_weights_list=[tf.ones(5)] * 3,
groundtruth_masks_list=[tf.ones((5, 32, 32))] * 3) groundtruth_masks_list=[tf.ones((5, 32, 32))] * 3,
groundtruth_keypoints_list=[tf.zeros((5, 10, 2))] * 3,
groundtruth_keypoint_depths_list=[tf.zeros((5, 10))] * 3)
loss = model.loss(prediction, tf.constant([[32, 32, 3.0]])) loss = model.loss(prediction, tf.constant([[32, 32, 3.0]]))
self.assertGreater(loss['Loss/deep_mask_estimation'], 0.0) self.assertGreater(loss['Loss/deep_mask_estimation'], 0.0)
...@@ -1495,11 +1504,15 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1495,11 +1504,15 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
classes = [tf.one_hot([1, 0, 1, 1, 1], depth=6)] classes = [tf.one_hot([1, 0, 1, 1, 1], depth=6)]
weights = [tf.ones(5)] weights = [tf.ones(5)]
masks = [tf.ones((5, 32, 32))] masks = [tf.ones((5, 32, 32))]
keypoints = [tf.zeros((5, 10, 2))]
keypoint_depths = [tf.ones((5, 10))]
model.provide_groundtruth( model.provide_groundtruth(
groundtruth_boxes_list=boxes, groundtruth_boxes_list=boxes,
groundtruth_classes_list=classes, groundtruth_classes_list=classes,
groundtruth_weights_list=weights, groundtruth_weights_list=weights,
groundtruth_masks_list=masks) groundtruth_masks_list=masks,
groundtruth_keypoints_list=keypoints,
groundtruth_keypoint_depths_list=keypoint_depths)
loss = model.loss(prediction, tf.constant([[32, 32, 3.0]])) loss = model.loss(prediction, tf.constant([[32, 32, 3.0]]))
self.assertGreater(loss['Loss/deep_mask_estimation'], 0.0) self.assertGreater(loss['Loss/deep_mask_estimation'], 0.0)
...@@ -1513,7 +1526,8 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1513,7 +1526,8 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
deepmac_meta_arch.DEEP_MASK_BOX_CONSISTENCY: rng.uniform(1, 5), deepmac_meta_arch.DEEP_MASK_BOX_CONSISTENCY: rng.uniform(1, 5),
deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY: rng.uniform(1, 5), deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY: rng.uniform(1, 5),
deepmac_meta_arch.DEEP_MASK_AUGMENTED_SELF_SUPERVISION: ( deepmac_meta_arch.DEEP_MASK_AUGMENTED_SELF_SUPERVISION: (
rng.uniform(1, 5)) rng.uniform(1, 5)),
deepmac_meta_arch.DEEP_MASK_POINTLY_SUPERVISED: rng.uniform(1, 5)
} }
weighted_model = build_meta_arch( weighted_model = build_meta_arch(
...@@ -1531,14 +1545,18 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1531,14 +1545,18 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
loss_weights[deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY]), loss_weights[deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY]),
augmented_self_supervision_loss_weight=( augmented_self_supervision_loss_weight=(
loss_weights[deepmac_meta_arch.DEEP_MASK_AUGMENTED_SELF_SUPERVISION] loss_weights[deepmac_meta_arch.DEEP_MASK_AUGMENTED_SELF_SUPERVISION]
) ),
pointly_supervised_keypoint_loss_weight=(
loss_weights[deepmac_meta_arch.DEEP_MASK_POINTLY_SUPERVISED])
) )
weighted_model.provide_groundtruth( weighted_model.provide_groundtruth(
groundtruth_boxes_list=boxes, groundtruth_boxes_list=boxes,
groundtruth_classes_list=classes, groundtruth_classes_list=classes,
groundtruth_weights_list=weights, groundtruth_weights_list=weights,
groundtruth_masks_list=masks) groundtruth_masks_list=masks,
groundtruth_keypoints_list=keypoints,
groundtruth_keypoint_depths_list=keypoint_depths)
weighted_loss = weighted_model.loss(prediction, tf.constant([[32, 32, 3]])) weighted_loss = weighted_model.loss(prediction, tf.constant([[32, 32, 3]]))
for mask_loss in deepmac_meta_arch.MASK_LOSSES: for mask_loss in deepmac_meta_arch.MASK_LOSSES:
...@@ -1613,6 +1631,36 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1613,6 +1631,36 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
self.assertAlmostEqual(loss_at_20[loss_key].numpy(), self.assertAlmostEqual(loss_at_20[loss_key].numpy(),
loss_at_100[loss_key].numpy()) loss_at_100[loss_key].numpy())
def test_pointly_supervised_loss(self):
tf.keras.backend.set_learning_phase(True)
model = build_meta_arch(
use_dice_loss=False,
predict_full_resolution_masks=True,
network_type='cond_inst1',
dim=9,
pixel_embedding_dim=8,
use_instance_embedding=False,
use_xy=False,
pointly_supervised_keypoint_loss_weight=1.0)
mask_logits = np.zeros((1, 1, 32, 32), dtype=np.float32)
keypoints = np.zeros((1, 1, 1, 2), dtype=np.float32)
keypoint_depths = np.zeros((1, 1, 1), dtype=np.float32)
keypoints[..., 0] = 0.5
keypoints[..., 1] = 0.5
keypoint_depths[..., 0] = 1.0
mask_logits[:, :, 16, 16] = 1.0
expected_loss = tf.nn.sigmoid_cross_entropy_with_logits(
logits=[[1.0]], labels=[[1.0]]
).numpy()
loss = model._compute_pointly_supervised_loss_from_keypoints(
mask_logits, keypoints, keypoint_depths)
self.assertEqual(loss.shape, (1, 1))
self.assertAllClose(expected_loss, loss)
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.') @unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class FullyConnectedMaskHeadTest(tf.test.TestCase): class FullyConnectedMaskHeadTest(tf.test.TestCase):
......
...@@ -41,6 +41,7 @@ from object_detection.utils import visualization_utils as vutils ...@@ -41,6 +41,7 @@ from object_detection.utils import visualization_utils as vutils
MODEL_BUILD_UTIL_MAP = model_lib.MODEL_BUILD_UTIL_MAP MODEL_BUILD_UTIL_MAP = model_lib.MODEL_BUILD_UTIL_MAP
NUM_STEPS_PER_ITERATION = 100 NUM_STEPS_PER_ITERATION = 100
LOG_EVERY = 100
RESTORE_MAP_ERROR_TEMPLATE = ( RESTORE_MAP_ERROR_TEMPLATE = (
...@@ -536,8 +537,7 @@ def train_loop( ...@@ -536,8 +537,7 @@ def train_loop(
# Write the as-run pipeline config to disk. # Write the as-run pipeline config to disk.
if save_final_config: if save_final_config:
tf.logging.info('Saving pipeline config file to directory {}'.format( tf.logging.info('Saving pipeline config file to directory %s', model_dir)
model_dir))
pipeline_config_final = create_pipeline_proto_from_configs(configs) pipeline_config_final = create_pipeline_proto_from_configs(configs)
config_util.save_pipeline_config(pipeline_config_final, model_dir) config_util.save_pipeline_config(pipeline_config_final, model_dir)
...@@ -699,7 +699,7 @@ def train_loop( ...@@ -699,7 +699,7 @@ def train_loop(
for key, val in logged_dict.items(): for key, val in logged_dict.items():
tf.compat.v2.summary.scalar(key, val, step=global_step) tf.compat.v2.summary.scalar(key, val, step=global_step)
if global_step.value() - logged_step >= 100: if global_step.value() - logged_step >= LOG_EVERY:
logged_dict_np = {name: value.numpy() for name, value in logged_dict_np = {name: value.numpy() for name, value in
logged_dict.items()} logged_dict.items()}
tf.logging.info( tf.logging.info(
...@@ -1091,8 +1091,7 @@ def eval_continuously( ...@@ -1091,8 +1091,7 @@ def eval_continuously(
configs = merge_external_params_with_configs( configs = merge_external_params_with_configs(
configs, None, kwargs_dict=kwargs) configs, None, kwargs_dict=kwargs)
if model_dir and save_final_config: if model_dir and save_final_config:
tf.logging.info('Saving pipeline config file to directory {}'.format( tf.logging.info('Saving pipeline config file to directory %s', model_dir)
model_dir))
pipeline_config_final = create_pipeline_proto_from_configs(configs) pipeline_config_final = create_pipeline_proto_from_configs(configs)
config_util.save_pipeline_config(pipeline_config_final, model_dir) config_util.save_pipeline_config(pipeline_config_final, model_dir)
...@@ -1104,11 +1103,11 @@ def eval_continuously( ...@@ -1104,11 +1103,11 @@ def eval_continuously(
eval_on_train_input_config.sample_1_of_n_examples = ( eval_on_train_input_config.sample_1_of_n_examples = (
sample_1_of_n_eval_on_train_examples) sample_1_of_n_eval_on_train_examples)
if override_eval_num_epochs and eval_on_train_input_config.num_epochs != 1: if override_eval_num_epochs and eval_on_train_input_config.num_epochs != 1:
tf.logging.warning('Expected number of evaluation epochs is 1, but ' tf.logging.warning(
'instead encountered `eval_on_train_input_config' ('Expected number of evaluation epochs is 1, but '
'.num_epochs` = ' 'instead encountered `eval_on_train_input_config'
'{}. Overwriting `num_epochs` to 1.'.format( '.num_epochs` = %d. Overwriting `num_epochs` to 1.'),
eval_on_train_input_config.num_epochs)) eval_on_train_input_config.num_epochs)
eval_on_train_input_config.num_epochs = 1 eval_on_train_input_config.num_epochs = 1
if kwargs['use_bfloat16']: if kwargs['use_bfloat16']:
......
...@@ -403,7 +403,7 @@ message CenterNet { ...@@ -403,7 +403,7 @@ message CenterNet {
// Mask prediction support using DeepMAC. See https://arxiv.org/abs/2104.00613 // Mask prediction support using DeepMAC. See https://arxiv.org/abs/2104.00613
// Next ID 33 // Next ID 34
message DeepMACMaskEstimation { message DeepMACMaskEstimation {
// The loss used for penalizing mask predictions. // The loss used for penalizing mask predictions.
optional ClassificationLoss classification_loss = 1; optional ClassificationLoss classification_loss = 1;
...@@ -520,6 +520,17 @@ message CenterNet { ...@@ -520,6 +520,17 @@ message CenterNet {
optional float augmented_self_supervision_scale_min = 31 [default=1.0]; optional float augmented_self_supervision_scale_min = 31 [default=1.0];
optional float augmented_self_supervision_scale_max = 32 [default=1.0]; optional float augmented_self_supervision_scale_max = 32 [default=1.0];
// The loss weight for the pointly supervised loss as defined in the paper
// https://arxiv.org/abs/2104.06404
// We assume that point supervision is given through a keypoint dataset,
// where each keypoint represents a sampled point, and its depth indicates
// whether it is a foreground or background point.
// Depth = +1 is assumed to be foreground and
// Depth = -1 is assumed to be background.
optional float pointly_supervised_keypoint_loss_weight = 33 [default = 0.0];
} }
optional DeepMACMaskEstimation deepmac_mask_estimation = 14; optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment