Commit 1cbb7a7d authored by Vighnesh Birodkar, committed by TF Object Detection Team

Add box consistency loss in DeepMAC.

PiperOrigin-RevId: 385646058
parent 76640072
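For context: the box consistency loss added here follows the projection idea from BoxInst (https://arxiv.org/abs/2012.02310). Reducing the predicted mask with a max along each spatial axis should reproduce the ground-truth box reduced the same way. Below is a minimal standalone sketch of that idea, assuming absolute box coordinates and a plain sigmoid cross-entropy penalty; the commit itself reuses the configured classification loss and, for fixed-size masks, first crops predictions and filled boxes to the (possibly jittered) ROIs.

import tensorflow as tf

def box_consistency_loss_sketch(boxes, mask_logits):
  # boxes: [N, 4] (ymin, xmin, ymax, xmax) in pixel coordinates of the logit
  # grid. mask_logits: [N, H, W] predicted mask logits. Returns a [N] loss.
  height, width = tf.shape(mask_logits)[1], tf.shape(mask_logits)[2]
  # Rasterize each box into a binary [N, H, W] mask (1.0 inside the box).
  ymin, xmin, ymax, xmax = tf.unstack(
      boxes[:, tf.newaxis, tf.newaxis, :], 4, axis=3)
  ygrid, xgrid = tf.meshgrid(tf.range(height), tf.range(width), indexing='ij')
  ygrid = tf.cast(ygrid, tf.float32)[tf.newaxis]
  xgrid = tf.cast(xgrid, tf.float32)[tf.newaxis]
  box_mask = tf.cast(
      tf.logical_and(tf.logical_and(ygrid >= ymin, ygrid <= ymax),
                     tf.logical_and(xgrid >= xmin, xgrid <= xmax)),
      tf.float32)
  loss = 0.0
  for axis in [1, 2]:
    # Max-projection collapses one spatial axis; the 1-D profile records
    # whether each remaining row/column contains any foreground.
    pred_proj = tf.reduce_max(mask_logits, axis=axis)
    gt_proj = tf.reduce_max(box_mask, axis=axis)
    loss += tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=gt_proj,
                                                logits=pred_proj), axis=1)
  return loss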
@@ -26,6 +26,7 @@ from object_detection.utils import spatial_transform_ops
INSTANCE_EMBEDDING = 'INSTANCE_EMBEDDING'
PIXEL_EMBEDDING = 'PIXEL_EMBEDDING'
DEEP_MASK_ESTIMATION = 'deep_mask_estimation'
DEEP_MASK_BOX_CONSISTENCY = 'deep_mask_box_consistency'
LOSS_KEY_PREFIX = center_net_meta_arch.LOSS_KEY_PREFIX
@@ -35,7 +36,7 @@ class DeepMACParams(
'allowed_masked_classes_ids', 'mask_size', 'mask_num_subsamples',
'use_xy', 'network_type', 'use_instance_embedding', 'num_init_channels',
'predict_full_resolution_masks', 'postprocess_crop_size',
- 'max_roi_jitter_ratio', 'roi_jitter_mode'
'max_roi_jitter_ratio', 'roi_jitter_mode', 'box_consistency_loss_weight'
])):
"""Class holding the DeepMAC network configutration.""" """Class holding the DeepMAC network configutration."""
@@ -46,7 +47,7 @@ class DeepMACParams(
mask_num_subsamples, use_xy, network_type, use_instance_embedding,
num_init_channels, predict_full_resolution_masks,
postprocess_crop_size, max_roi_jitter_ratio,
- roi_jitter_mode):
roi_jitter_mode, box_consistency_loss_weight):
return super(DeepMACParams,
cls).__new__(cls, classification_loss, dim,
task_loss_weight, pixel_embedding_dim,
@@ -55,7 +56,7 @@ class DeepMACParams(
use_instance_embedding, num_init_channels,
predict_full_resolution_masks,
postprocess_crop_size, max_roi_jitter_ratio,
- roi_jitter_mode)
roi_jitter_mode, box_consistency_loss_weight)
def subsample_instances(classes, weights, boxes, masks, num_subsamples):
@@ -206,6 +207,61 @@ def filter_masked_classes(masked_class_ids, classes, weights, masks):
)
def crop_and_resize_feature_map(features, boxes, size):
"""Crop and resize regions from a single feature map given a set of boxes.
Args:
features: A [H, W, C] float tensor.
boxes: A [N, 4] tensor of normalized boxes.
size: int, the size of the output features.
Returns:
per_box_features: A [N, size, size, C] tensor of cropped and resized
features.
"""
return spatial_transform_ops.matmul_crop_and_resize(
features[tf.newaxis], boxes[tf.newaxis], [size, size])[0]
def crop_and_resize_instance_masks(masks, boxes, mask_size):
"""Crop and resize each mask according to the given boxes.
Args:
masks: A [N, H, W] float tensor.
boxes: A [N, 4] float tensor of normalized boxes.
mask_size: int, the size of the output masks.
Returns:
masks: A [N, mask_size, mask_size] float tensor of cropped and resized
instance masks.
"""
cropped_masks = spatial_transform_ops.matmul_crop_and_resize(
masks[:, :, :, tf.newaxis], boxes[:, tf.newaxis, :],
[mask_size, mask_size])
cropped_masks = tf.squeeze(cropped_masks, axis=[1, 4])
return cropped_masks
def fill_boxes(boxes, height, width):
"""Fills the area included in the box."""
blist = box_list.BoxList(boxes)
blist = box_list_ops.to_absolute_coordinates(blist, height, width)
boxes = blist.get()
ymin, xmin, ymax, xmax = tf.unstack(
boxes[:, tf.newaxis, tf.newaxis, :], 4, axis=3)
ygrid, xgrid = tf.meshgrid(tf.range(height), tf.range(width), indexing='ij')
ygrid, xgrid = tf.cast(ygrid, tf.float32), tf.cast(xgrid, tf.float32)
ygrid, xgrid = ygrid[tf.newaxis, :, :], xgrid[tf.newaxis, :, :]
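# A pixel (y, x) is inside a box iff ymin <= y <= ymax and xmin <= x <= xmax;
# broadcasting compares every grid location against every box at once.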
filled_boxes = tf.logical_and(
tf.logical_and(ygrid >= ymin, ygrid <= ymax),
tf.logical_and(xgrid >= xmin, xgrid <= xmax))
return tf.cast(filled_boxes, tf.float32)
class ResNetMaskNetwork(tf.keras.layers.Layer):
"""A small wrapper around ResNet blocks to predict masks."""
@@ -379,7 +435,8 @@ def deepmac_proto_to_params(deepmac_config):
deepmac_config.predict_full_resolution_masks,
postprocess_crop_size=deepmac_config.postprocess_crop_size,
max_roi_jitter_ratio=deepmac_config.max_roi_jitter_ratio,
- roi_jitter_mode=jitter_mode
roi_jitter_mode=jitter_mode,
box_consistency_loss_weight=deepmac_config.box_consistency_loss_weight
)
@@ -402,6 +459,13 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
"""Constructs the super class with object center & detection params only."""
self._deepmac_params = deepmac_params
if (self._deepmac_params.predict_full_resolution_masks and
self._deepmac_params.max_roi_jitter_ratio > 0.0):
raise ValueError('Jittering is not supported for full res masks.')
if self._deepmac_params.mask_num_subsamples > 0:
raise ValueError('Subsampling masks is currently not supported.')
super(DeepMACMetaArch, self).__init__(
is_training=is_training, add_summaries=add_summaries,
num_classes=num_classes, feature_extractor=feature_extractor,
@@ -462,21 +526,34 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
pixel_embedding = pixel_embedding[tf.newaxis, :, :, :]
pixel_embeddings_processed = tf.tile(pixel_embedding,
[num_instances, 1, 1, 1])
image_shape = tf.shape(pixel_embeddings_processed)
image_height, image_width = image_shape[1], image_shape[2]
y_grid, x_grid = tf.meshgrid(tf.linspace(0.0, 1.0, image_height),
tf.linspace(0.0, 1.0, image_width),
indexing='ij')
blist = box_list.BoxList(boxes)
ycenter, xcenter, _, _ = blist.get_center_coordinates_and_sizes()
y_grid = y_grid[tf.newaxis, :, :]
x_grid = x_grid[tf.newaxis, :, :]
y_grid -= ycenter[:, tf.newaxis, tf.newaxis]
x_grid -= xcenter[:, tf.newaxis, tf.newaxis]
coords = tf.stack([y_grid, x_grid], axis=3)
else:
# TODO(vighneshb) Explore multilevel_roi_align and align_corners=False.
- pixel_embeddings_cropped = spatial_transform_ops.matmul_crop_and_resize(
- pixel_embedding[tf.newaxis], boxes[tf.newaxis],
- [mask_size, mask_size])
- pixel_embeddings_processed = pixel_embeddings_cropped[0]
- mask_shape = tf.shape(pixel_embeddings_processed)
- mask_height, mask_width = mask_shape[1], mask_shape[2]
- y_grid, x_grid = tf.meshgrid(tf.linspace(-1.0, 1.0, mask_height),
- tf.linspace(-1.0, 1.0, mask_width),
- indexing='ij')
- coords = tf.stack([y_grid, x_grid], axis=2)
- coords = coords[tf.newaxis, :, :, :]
- coords = tf.tile(coords, [num_instances, 1, 1, 1])
pixel_embeddings_processed = crop_and_resize_feature_map(
pixel_embedding, boxes, mask_size)
mask_shape = tf.shape(pixel_embeddings_processed)
mask_height, mask_width = mask_shape[1], mask_shape[2]
y_grid, x_grid = tf.meshgrid(tf.linspace(-1.0, 1.0, mask_height),
tf.linspace(-1.0, 1.0, mask_width),
indexing='ij')
coords = tf.stack([y_grid, x_grid], axis=2)
coords = coords[tf.newaxis, :, :, :]
coords = tf.tile(coords, [num_instances, 1, 1, 1])
if self._deepmac_params.use_xy:
return tf.concat([coords, pixel_embeddings_processed], axis=3)
@@ -528,11 +605,9 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
if self._deepmac_params.predict_full_resolution_masks:
return masks
else:
- cropped_masks = spatial_transform_ops.matmul_crop_and_resize(
- masks[:, :, :, tf.newaxis], boxes[:, tf.newaxis, :],
- [mask_size, mask_size])
cropped_masks = crop_and_resize_instance_masks(
masks, boxes, mask_size)
cropped_masks = tf.stop_gradient(cropped_masks)
- cropped_masks = tf.squeeze(cropped_masks, axis=[1, 4])
# TODO(vighneshb) should we discretize masks?
return cropped_masks
@@ -543,7 +618,64 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return resize_instance_masks(logits, (height, width))
- def _compute_per_instance_mask_loss(
def _compute_per_instance_mask_prediction_loss(
self, boxes, mask_logits, mask_gt):
num_instances = tf.shape(boxes)[0]
mask_logits = self._resize_logits_like_gt(mask_logits, mask_gt)
mask_logits = tf.reshape(mask_logits, [num_instances, -1, 1])
mask_gt = tf.reshape(mask_gt, [num_instances, -1, 1])
loss = self._deepmac_params.classification_loss(
prediction_tensor=mask_logits,
target_tensor=mask_gt,
weights=tf.ones_like(mask_logits))
# TODO(vighneshb) Make this configurable via config.
# Skip normalization for dice loss because the denominator term already
# does normalization.
if isinstance(self._deepmac_params.classification_loss,
losses.WeightedDiceClassificationLoss):
return tf.reduce_sum(loss, axis=1)
else:
return tf.reduce_mean(loss, axis=[1, 2])
def _compute_per_instance_box_consistency_loss(
self, boxes_gt, boxes_for_crop, mask_logits):
height, width = tf.shape(mask_logits)[1], tf.shape(mask_logits)[2]
filled_boxes = fill_boxes(boxes_gt, height, width)[:, :, :, tf.newaxis]
mask_logits = mask_logits[:, :, :, tf.newaxis]
if self._deepmac_params.predict_full_resolution_masks:
gt_crop = filled_boxes[:, :, :, 0]
pred_crop = mask_logits[:, :, :, 0]
else:
gt_crop = crop_and_resize_instance_masks(
filled_boxes, boxes_for_crop, self._deepmac_params.mask_size)
pred_crop = crop_and_resize_instance_masks(
mask_logits, boxes_for_crop, self._deepmac_params.mask_size)
loss = 0.0
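# Max-project the predicted logits and the filled ground-truth box along each
# spatial axis; the resulting 1-D profiles should match if the mask covers
# exactly the rows and columns spanned by the box (BoxInst-style consistency).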
for axis in [1, 2]:
pred_max = tf.reduce_max(pred_crop, axis=axis)[:, :, tf.newaxis]
gt_max = tf.reduce_max(gt_crop, axis=axis)[:, :, tf.newaxis]
axis_loss = self._deepmac_params.classification_loss(
prediction_tensor=pred_max,
target_tensor=gt_max,
weights=tf.ones_like(pred_max))
loss += axis_loss
# Skip normalization for dice loss because the denominator term already
# does normalization.
# TODO(vighneshb) Make this configurable via config.
if isinstance(self._deepmac_params.classification_loss,
losses.WeightedDiceClassificationLoss):
return tf.reduce_sum(loss, axis=1)
else:
return tf.reduce_mean(loss, axis=[1, 2])
def _compute_per_instance_deepmac_losses(
self, boxes, masks, instance_embedding, pixel_embedding):
"""Returns the mask prediction and box consistency losses per instance.
@@ -558,40 +690,36 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
pixel_embedding_size] float tensor containing the per-pixel embeddings.
Returns:
- mask_loss: A [num_instances] shaped float tensor containing the
mask_prediction_loss: A [num_instances] shaped float tensor containing the
mask loss for each instance.
box_consistency_loss: A [num_instances] shaped float tensor containing
the box consistency loss for each instance.
"""
- num_instances = tf.shape(boxes)[0]
if tf.keras.backend.learning_phase():
- boxes = preprocessor.random_jitter_boxes(
boxes_for_crop = preprocessor.random_jitter_boxes(
boxes, self._deepmac_params.max_roi_jitter_ratio,
jitter_mode=self._deepmac_params.roi_jitter_mode)
else:
boxes_for_crop = boxes
mask_input = self._get_mask_head_input(
- boxes, pixel_embedding)
boxes_for_crop, pixel_embedding)
instance_embeddings = self._get_instance_embeddings(
- boxes, instance_embedding)
boxes_for_crop, instance_embedding)
mask_logits = self._mask_net(
instance_embeddings, mask_input,
training=tf.keras.backend.learning_phase())
- mask_gt = self._get_groundtruth_mask_output(boxes, masks)
- mask_logits = self._resize_logits_like_gt(mask_logits, mask_gt)
- mask_logits = tf.reshape(mask_logits, [num_instances, -1, 1])
- mask_gt = tf.reshape(mask_gt, [num_instances, -1, 1])
- loss = self._deepmac_params.classification_loss(
- prediction_tensor=mask_logits,
- target_tensor=mask_gt,
- weights=tf.ones_like(mask_logits))
- # TODO(vighneshb) Make this configurable via config.
- if isinstance(self._deepmac_params.classification_loss,
- losses.WeightedDiceClassificationLoss):
- return tf.reduce_sum(loss, axis=1)
- else:
- return tf.reduce_mean(loss, axis=[1, 2])
mask_gt = self._get_groundtruth_mask_output(boxes_for_crop, masks)
mask_prediction_loss = self._compute_per_instance_mask_prediction_loss(
boxes_for_crop, mask_logits, mask_gt)
box_consistency_loss = self._compute_per_instance_box_consistency_loss(
boxes, boxes_for_crop, mask_logits)
return mask_prediction_loss, box_consistency_loss
def _compute_instance_masks_loss(self, prediction_dict):
"""Computes the mask loss.
@@ -603,7 +731,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
[batch_size, height, width, embedding_size].
Returns:
- loss: float, the mask loss as a scalar.
loss_dict: A dict mapping string (loss names) to scalar floats.
"""
gt_boxes_list = self.groundtruth_lists(fields.BoxListFields.boxes)
gt_weights_list = self.groundtruth_lists(fields.BoxListFields.weights)
@@ -613,7 +741,10 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
allowed_masked_classes_ids = (
self._deepmac_params.allowed_masked_classes_ids)
- total_loss = 0.0
loss_dict = {
DEEP_MASK_ESTIMATION: 0.0,
DEEP_MASK_BOX_CONSISTENCY: 0.0
}
# Iterate over multiple predictions by the backbone (for hourglass, length=2)
for instance_pred, pixel_pred in zip(
@@ -625,24 +756,31 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
for i, (boxes, weights, classes, masks) in enumerate(
zip(gt_boxes_list, gt_weights_list, gt_classes_list, gt_masks_list)):
- _, weights, masks = filter_masked_classes(allowed_masked_classes_ids,
- classes, weights, masks)
- num_subsample = self._deepmac_params.mask_num_subsamples
- _, weights, boxes, masks = subsample_instances(
- classes, weights, boxes, masks, num_subsample)
- per_instance_loss = self._compute_per_instance_mask_loss(
- boxes, masks, instance_pred[i], pixel_pred[i])
- per_instance_loss *= weights
# TODO(vighneshb) Add sub-sampling back if required.
classes, valid_mask_weights, masks = filter_masked_classes(
allowed_masked_classes_ids, classes, weights, masks)
per_instance_mask_loss, per_instance_consistency_loss = (
self._compute_per_instance_deepmac_losses(
boxes, masks, instance_pred[i], pixel_pred[i]))
per_instance_mask_loss *= valid_mask_weights
per_instance_consistency_loss *= weights
num_instances = tf.maximum(tf.reduce_sum(weights), 1.0)
num_instances_allowed = tf.maximum(
tf.reduce_sum(valid_mask_weights), 1.0)
- total_loss += tf.reduce_sum(per_instance_loss) / num_instances
loss_dict[DEEP_MASK_ESTIMATION] += (
tf.reduce_sum(per_instance_mask_loss) / num_instances_allowed)
loss_dict[DEEP_MASK_BOX_CONSISTENCY] += (
tf.reduce_sum(per_instance_consistency_loss) / num_instances)
batch_size = len(gt_boxes_list)
num_predictions = len(prediction_dict[INSTANCE_EMBEDDING])
- return total_loss / float(batch_size * num_predictions)
return dict((key, loss / float(batch_size * num_predictions))
for key, loss in loss_dict.items())
def loss(self, prediction_dict, true_image_shapes, scope=None):
@@ -650,13 +788,19 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
prediction_dict, true_image_shapes, scope)
if self._deepmac_params is not None:
- mask_loss = self._compute_instance_masks_loss(
mask_loss_dict = self._compute_instance_masks_loss(
prediction_dict=prediction_dict)
- key = LOSS_KEY_PREFIX + '/' + DEEP_MASK_ESTIMATION
- losses_dict[key] = (
- self._deepmac_params.task_loss_weight * mask_loss
losses_dict[LOSS_KEY_PREFIX + '/' + DEEP_MASK_ESTIMATION] = (
self._deepmac_params.task_loss_weight * mask_loss_dict[
DEEP_MASK_ESTIMATION]
)
if self._deepmac_params.box_consistency_loss_weight > 0.0:
losses_dict[LOSS_KEY_PREFIX + '/' + DEEP_MASK_BOX_CONSISTENCY] = (
self._deepmac_params.box_consistency_loss_weight * mask_loss_dict[
DEEP_MASK_BOX_CONSISTENCY]
)
return losses_dict
def postprocess(self, prediction_dict, true_image_shapes, **params):
......
@@ -60,7 +60,8 @@ class MockMaskNet(tf.keras.layers.Layer):
return tf.zeros_like(pixel_embedding[:, :, :, 0]) + 0.9
- def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False):
def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
mask_num_subsamples=-1):
"""Builds the DeepMAC meta architecture."""
feature_extractor = DummyFeatureExtractor(
@@ -94,7 +95,7 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False):
pixel_embedding_dim=2,
allowed_masked_classes_ids=[],
mask_size=16,
- mask_num_subsamples=-1,
mask_num_subsamples=mask_num_subsamples,
use_xy=True,
network_type='hourglass10',
use_instance_embedding=True,
@@ -102,7 +103,8 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False):
predict_full_resolution_masks=predict_full_resolution_masks,
postprocess_crop_size=128,
max_roi_jitter_ratio=0.0,
- roi_jitter_mode='random'
roi_jitter_mode='random',
box_consistency_loss_weight=1.0,
)
object_detection_params = center_net_meta_arch.ObjectDetectionParams(
@@ -140,6 +142,33 @@ class DeepMACUtilsTest(tf.test.TestCase):
self.assertAllClose(result[2], boxes)
self.assertAllClose(result[3], masks)
def test_fill_boxes(self):
boxes = tf.constant([[0., 0., 0.5, 0.5], [0.5, 0.5, 1.0, 1.0]])
filled_boxes = deepmac_meta_arch.fill_boxes(boxes, 32, 32)
expected = np.zeros((2, 32, 32))
expected[0, :17, :17] = 1.0
expected[1, 16:, 16:] = 1.0
self.assertAllClose(expected, filled_boxes.numpy(), rtol=1e-3)
def test_crop_and_resize_instance_masks(self):
boxes = tf.zeros((5, 4))
masks = tf.zeros((5, 128, 128))
output = deepmac_meta_arch.crop_and_resize_instance_masks(
masks, boxes, 32)
self.assertEqual(output.shape, (5, 32, 32))
def test_crop_and_resize_feature_map(self):
boxes = tf.zeros((5, 4))
features = tf.zeros((128, 128, 7))
output = deepmac_meta_arch.crop_and_resize_feature_map(
features, boxes, 32)
self.assertEqual(output.shape, (5, 32, 32, 7))
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMetaArchTest(tf.test.TestCase):
@@ -199,7 +228,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
def test_get_mask_head_input_no_crop_resize(self):
model = build_meta_arch(predict_full_resolution_masks=True)
- boxes = tf.constant([[0., 0., 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
boxes = tf.constant([[0., 0., 1.0, 1.0], [0.0, 0.0, 0.5, 1.0]],
dtype=tf.float32)
pixel_embedding_np = np.random.randn(32, 32, 4).astype(np.float32)
@@ -208,12 +237,15 @@ class DeepMACMetaArchTest(tf.test.TestCase):
mask_inputs = model._get_mask_head_input(boxes, pixel_embedding)
self.assertEqual(mask_inputs.shape, (2, 32, 32, 6))
- y_grid, x_grid = tf.meshgrid(np.linspace(-1.0, 1.0, 32),
- np.linspace(-1.0, 1.0, 32), indexing='ij')
y_grid, x_grid = tf.meshgrid(np.linspace(.0, 1.0, 32),
np.linspace(.0, 1.0, 32), indexing='ij')
ys = [0.5, 0.25]
xs = [0.5, 0.5]
for i in range(2):
mask_input = mask_inputs[i]
- self.assertAllClose(y_grid, mask_input[:, :, 0])
- self.assertAllClose(x_grid, mask_input[:, :, 1])
self.assertAllClose(y_grid - ys[i], mask_input[:, :, 0])
self.assertAllClose(x_grid - xs[i], mask_input[:, :, 1])
pixel_embedding = mask_input[:, :, 2:]
self.assertAllClose(pixel_embedding_np, pixel_embedding)
@@ -262,7 +294,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
masks[1, 16:, 16:] = 1.0
masks = tf.constant(masks)
- loss = model._compute_per_instance_mask_loss(
loss, _ = model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)))
self.assertAllClose(
loss, np.zeros(2) - tf.math.log(tf.nn.sigmoid(0.9)))
@@ -275,7 +307,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
masks = np.ones((2, 128, 128), dtype=np.float32)
masks = tf.constant(masks)
- loss = model._compute_per_instance_mask_loss(
loss, _ = model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)))
self.assertAllClose(
loss, np.zeros(2) - tf.math.log(tf.nn.sigmoid(0.9)))
@@ -289,7 +321,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
masks = np.ones((2, 128, 128), dtype=np.float32)
masks = tf.constant(masks)
- loss = model._compute_per_instance_mask_loss(
loss, _ = model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)))
pred = tf.nn.sigmoid(0.9)
expected = (1.0 - ((2.0 * pred) / (1.0 + pred)))
@@ -299,7 +331,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
boxes = tf.zeros([0, 4])
masks = tf.zeros([0, 128, 128])
- loss = self.model._compute_per_instance_mask_loss(
loss, _ = self.model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)))
self.assertEqual(loss.shape, (0,))
@@ -394,6 +426,59 @@ class DeepMACMetaArchTest(tf.test.TestCase):
out = call_func(tf.zeros((2, 4)), tf.zeros((2, 32, 32, 8)), training=True)
self.assertEqual(out.shape, (2, 32, 32))
def test_box_consistency_loss(self):
boxes_gt = tf.constant([[0., 0., 0.49, 1.0]])
boxes_jittered = tf.constant([[0.0, 0.0, 1.0, 1.0]])
mask_prediction = np.zeros((1, 32, 32)).astype(np.float32)
mask_prediction[0, :24, :24] = 1.0
loss = self.model._compute_per_instance_box_consistency_loss(
boxes_gt, boxes_jittered, tf.constant(mask_prediction))
yloss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.constant([1.0] * 8 + [0.0] * 8),
logits=[1.0] * 12 + [0.0] * 4)
xloss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.constant([1.0] * 16),
logits=[1.0] * 12 + [0.0] * 4)
self.assertAllClose(loss, [tf.reduce_mean(yloss + xloss).numpy()])
def test_box_consistency_dice_loss(self):
model = build_meta_arch(use_dice_loss=True)
boxes_gt = tf.constant([[0., 0., 0.49, 1.0]])
boxes_jittered = tf.constant([[0.0, 0.0, 1.0, 1.0]])
almost_inf = 1e10
mask_prediction = np.full((1, 32, 32), -almost_inf, dtype=np.float32)
mask_prediction[0, :24, :24] = almost_inf
loss = model._compute_per_instance_box_consistency_loss(
boxes_gt, boxes_jittered, tf.constant(mask_prediction))
yloss = 1 - 6.0 / 7
xloss = 0.2
self.assertAllClose(loss, [yloss + xloss])
def test_box_consistency_dice_loss_full_res(self):
model = build_meta_arch(use_dice_loss=True,
predict_full_resolution_masks=True)
boxes_gt = tf.constant([[0., 0., 1.0, 1.0]])
boxes_jittered = None
almost_inf = 1e10
mask_prediction = np.full((1, 32, 32), -almost_inf, dtype=np.float32)
mask_prediction[0, :16, :32] = almost_inf
loss = model._compute_per_instance_box_consistency_loss(
boxes_gt, boxes_jittered, tf.constant(mask_prediction))
self.assertAlmostEqual(loss[0].numpy(), 1 / 3)
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class FullyConnectedMaskHeadTest(tf.test.TestCase):
......
@@ -446,6 +446,10 @@ message CenterNet {
// The mode for jittering box ROIs. See RandomJitterBoxes in
// preprocessor.proto for more details.
optional RandomJitterBoxes.JitterMode jitter_mode = 15 [default=DEFAULT];
// Weight for the box consistency loss as described in the BoxInst paper
// https://arxiv.org/abs/2012.02310
optional float box_consistency_loss_weight = 16 [default=0.0];
}
optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
......
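For reference, a hypothetical pipeline-config fragment enabling the new field (text-format proto; the model/center_net nesting is assumed from the usual CenterNet pipeline config, only the DeepMAC field itself comes from this change):

model {
  center_net {
    deepmac_mask_estimation {
      box_consistency_loss_weight: 1.0  # default is 0.0, which disables the term
    }
  }
}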