"test/vscode:/vscode.git/clone" did not exist on "5ff25cdf5b1310e83d9e595142b39ae4d7b561e9"
Commit 1cbb7a7d authored by Vighnesh Birodkar, committed by TF Object Detection Team

Add box consistency loss in DeepMAC.

PiperOrigin-RevId: 385646058
parent 76640072
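For context, the box consistency loss follows BoxInst (https://arxiv.org/abs/2012.02310): for each instance, the predicted mask logits are max-projected onto the y and x axes and compared against the same projections of the filled ground-truth box, so a mask is penalized for leaking outside its box or for failing to span it. A minimal standalone sketch of that idea (toy shapes; this is not the code added in this commit, which lives in _compute_per_instance_box_consistency_loss below):

import tensorflow as tf

# One instance: an 8x8 grid of mask logits and a filled ground-truth box
# covering the top half of the grid.
mask_logits = tf.random.normal([1, 8, 8])
filled_box = tf.concat([tf.ones([1, 4, 8]), tf.zeros([1, 4, 8])], axis=1)

consistency_loss = 0.0
for axis in (1, 2):  # axis=1 collapses height (per-column max), axis=2 collapses width (per-row max)
  pred_proj = tf.reduce_max(mask_logits, axis=axis)  # [1, 8]
  gt_proj = tf.reduce_max(filled_box, axis=axis)     # [1, 8]
  consistency_loss += tf.reduce_mean(
      tf.nn.sigmoid_cross_entropy_with_logits(labels=gt_proj,
                                              logits=pred_proj))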
@@ -26,6 +26,7 @@ from object_detection.utils import spatial_transform_ops
INSTANCE_EMBEDDING = 'INSTANCE_EMBEDDING'
PIXEL_EMBEDDING = 'PIXEL_EMBEDDING'
DEEP_MASK_ESTIMATION = 'deep_mask_estimation'
DEEP_MASK_BOX_CONSISTENCY = 'deep_mask_box_consistency'
LOSS_KEY_PREFIX = center_net_meta_arch.LOSS_KEY_PREFIX
@@ -35,7 +36,7 @@ class DeepMACParams(
'allowed_masked_classes_ids', 'mask_size', 'mask_num_subsamples',
'use_xy', 'network_type', 'use_instance_embedding', 'num_init_channels',
'predict_full_resolution_masks', 'postprocess_crop_size',
'max_roi_jitter_ratio', 'roi_jitter_mode'
'max_roi_jitter_ratio', 'roi_jitter_mode', 'box_consistency_loss_weight'
])):
"""Class holding the DeepMAC network configutration."""
@@ -46,7 +47,7 @@ class DeepMACParams(
mask_num_subsamples, use_xy, network_type, use_instance_embedding,
num_init_channels, predict_full_resolution_masks,
postprocess_crop_size, max_roi_jitter_ratio,
roi_jitter_mode):
roi_jitter_mode, box_consistency_loss_weight):
return super(DeepMACParams,
cls).__new__(cls, classification_loss, dim,
task_loss_weight, pixel_embedding_dim,
@@ -55,7 +56,7 @@ class DeepMACParams(
use_instance_embedding, num_init_channels,
predict_full_resolution_masks,
postprocess_crop_size, max_roi_jitter_ratio,
roi_jitter_mode)
roi_jitter_mode, box_consistency_loss_weight)
def subsample_instances(classes, weights, boxes, masks, num_subsamples):
@@ -206,6 +207,61 @@ def filter_masked_classes(masked_class_ids, classes, weights, masks):
)
def crop_and_resize_feature_map(features, boxes, size):
"""Crop and resize regions from a single feature map given a set of boxes.
Args:
features: A [H, W, C] float tensor.
boxes: A [N, 4] tensor of normalized boxes.
size: int, the size of the output features.
Returns:
per_box_features: A [N, size, size, C] tensor of cropped and resized
features.
"""
return spatial_transform_ops.matmul_crop_and_resize(
features[tf.newaxis], boxes[tf.newaxis], [size, size])[0]
def crop_and_resize_instance_masks(masks, boxes, mask_size):
"""Crop and resize each mask according to the given boxes.
Args:
masks: A [N, H, W] float tensor.
boxes: A [N, 4] float tensor of normalized boxes.
mask_size: int, the size of the output masks.
Returns:
masks: A [N, mask_size, mask_size] float tensor of cropped and resized
instance masks.
"""
cropped_masks = spatial_transform_ops.matmul_crop_and_resize(
masks[:, :, :, tf.newaxis], boxes[:, tf.newaxis, :],
[mask_size, mask_size])
cropped_masks = tf.squeeze(cropped_masks, axis=[1, 4])
return cropped_masks
def fill_boxes(boxes, height, width):
"""Fills the area included in the box."""
blist = box_list.BoxList(boxes)
blist = box_list_ops.to_absolute_coordinates(blist, height, width)
boxes = blist.get()
ymin, xmin, ymax, xmax = tf.unstack(
boxes[:, tf.newaxis, tf.newaxis, :], 4, axis=3)
ygrid, xgrid = tf.meshgrid(tf.range(height), tf.range(width), indexing='ij')
ygrid, xgrid = tf.cast(ygrid, tf.float32), tf.cast(xgrid, tf.float32)
ygrid, xgrid = ygrid[tf.newaxis, :, :], xgrid[tf.newaxis, :, :]
filled_boxes = tf.logical_and(
tf.logical_and(ygrid >= ymin, ygrid <= ymax),
tf.logical_and(xgrid >= xmin, xgrid <= xmax))
return tf.cast(filled_boxes, tf.float32)
class ResNetMaskNetwork(tf.keras.layers.Layer):
"""A small wrapper around ResNet blocks to predict masks."""
@@ -379,7 +435,8 @@ def deepmac_proto_to_params(deepmac_config):
deepmac_config.predict_full_resolution_masks,
postprocess_crop_size=deepmac_config.postprocess_crop_size,
max_roi_jitter_ratio=deepmac_config.max_roi_jitter_ratio,
roi_jitter_mode=jitter_mode
roi_jitter_mode=jitter_mode,
box_consistency_loss_weight=deepmac_config.box_consistency_loss_weight
)
@@ -402,6 +459,13 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
"""Constructs the super class with object center & detection params only."""
self._deepmac_params = deepmac_params
if (self._deepmac_params.predict_full_resolution_masks and
self._deepmac_params.max_roi_jitter_ratio > 0.0):
raise ValueError('Jittering is not supported for full res masks.')
if self._deepmac_params.mask_num_subsamples > 0:
raise ValueError('Subsampling masks is currently not supported.')
super(DeepMACMetaArch, self).__init__(
is_training=is_training, add_summaries=add_summaries,
num_classes=num_classes, feature_extractor=feature_extractor,
@@ -462,18 +526,31 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
pixel_embedding = pixel_embedding[tf.newaxis, :, :, :]
pixel_embeddings_processed = tf.tile(pixel_embedding,
[num_instances, 1, 1, 1])
image_shape = tf.shape(pixel_embeddings_processed)
image_height, image_width = image_shape[1], image_shape[2]
y_grid, x_grid = tf.meshgrid(tf.linspace(0.0, 1.0, image_height),
tf.linspace(0.0, 1.0, image_width),
indexing='ij')
blist = box_list.BoxList(boxes)
ycenter, xcenter, _, _ = blist.get_center_coordinates_and_sizes()
y_grid = y_grid[tf.newaxis, :, :]
x_grid = x_grid[tf.newaxis, :, :]
y_grid -= ycenter[:, tf.newaxis, tf.newaxis]
x_grid -= xcenter[:, tf.newaxis, tf.newaxis]
coords = tf.stack([y_grid, x_grid], axis=3)
else:
# TODO(vighneshb) Explore multilevel_roi_align and align_corners=False.
pixel_embeddings_cropped = spatial_transform_ops.matmul_crop_and_resize(
pixel_embedding[tf.newaxis], boxes[tf.newaxis],
[mask_size, mask_size])
pixel_embeddings_processed = pixel_embeddings_cropped[0]
pixel_embeddings_processed = crop_and_resize_feature_map(
pixel_embedding, boxes, mask_size)
mask_shape = tf.shape(pixel_embeddings_processed)
mask_height, mask_width = mask_shape[1], mask_shape[2]
y_grid, x_grid = tf.meshgrid(tf.linspace(-1.0, 1.0, mask_height),
tf.linspace(-1.0, 1.0, mask_width),
indexing='ij')
coords = tf.stack([y_grid, x_grid], axis=2)
coords = coords[tf.newaxis, :, :, :]
coords = tf.tile(coords, [num_instances, 1, 1, 1])
@@ -528,11 +605,9 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
if self._deepmac_params.predict_full_resolution_masks:
return masks
else:
cropped_masks = spatial_transform_ops.matmul_crop_and_resize(
masks[:, :, :, tf.newaxis], boxes[:, tf.newaxis, :],
[mask_size, mask_size])
cropped_masks = crop_and_resize_instance_masks(
masks, boxes, mask_size)
cropped_masks = tf.stop_gradient(cropped_masks)
cropped_masks = tf.squeeze(cropped_masks, axis=[1, 4])
# TODO(vighneshb) should we discretize masks?
return cropped_masks
@@ -543,7 +618,64 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return resize_instance_masks(logits, (height, width))
def _compute_per_instance_mask_loss(
def _compute_per_instance_mask_prediction_loss(
self, boxes, mask_logits, mask_gt):
num_instances = tf.shape(boxes)[0]
mask_logits = self._resize_logits_like_gt(mask_logits, mask_gt)
mask_logits = tf.reshape(mask_logits, [num_instances, -1, 1])
mask_gt = tf.reshape(mask_gt, [num_instances, -1, 1])
loss = self._deepmac_params.classification_loss(
prediction_tensor=mask_logits,
target_tensor=mask_gt,
weights=tf.ones_like(mask_logits))
# TODO(vighneshb) Make this configurable via config.
# Skip normalization for dice loss because the denominator term already
# does normalization.
if isinstance(self._deepmac_params.classification_loss,
losses.WeightedDiceClassificationLoss):
return tf.reduce_sum(loss, axis=1)
else:
return tf.reduce_mean(loss, axis=[1, 2])
def _compute_per_instance_box_consistency_loss(
self, boxes_gt, boxes_for_crop, mask_logits):
height, width = tf.shape(mask_logits)[1], tf.shape(mask_logits)[2]
filled_boxes = fill_boxes(boxes_gt, height, width)[:, :, :, tf.newaxis]
mask_logits = mask_logits[:, :, :, tf.newaxis]
if self._deepmac_params.predict_full_resolution_masks:
gt_crop = filled_boxes[:, :, :, 0]
pred_crop = mask_logits[:, :, :, 0]
else:
gt_crop = crop_and_resize_instance_masks(
filled_boxes, boxes_for_crop, self._deepmac_params.mask_size)
pred_crop = crop_and_resize_instance_masks(
mask_logits, boxes_for_crop, self._deepmac_params.mask_size)
loss = 0.0
for axis in [1, 2]:
pred_max = tf.reduce_max(pred_crop, axis=axis)[:, :, tf.newaxis]
gt_max = tf.reduce_max(gt_crop, axis=axis)[:, :, tf.newaxis]
axis_loss = self._deepmac_params.classification_loss(
prediction_tensor=pred_max,
target_tensor=gt_max,
weights=tf.ones_like(pred_max))
loss += axis_loss
# Skip normalization for dice loss because the denominator term already
# does normalization.
# TODO(vighneshb) Make this configurable via config.
if isinstance(self._deepmac_params.classification_loss,
losses.WeightedDiceClassificationLoss):
return tf.reduce_sum(loss, axis=1)
else:
return tf.reduce_mean(loss, axis=[1, 2])
def _compute_per_instance_deepmac_losses(
self, boxes, masks, instance_embedding, pixel_embedding):
"""Returns the mask loss per instance.
@@ -558,40 +690,36 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
pixel_embedding_size] float tensor containing the per-pixel embeddings.
Returns:
mask_loss: A [num_instances] shaped float tensor containing the
mask_prediction_loss: A [num_instances] shaped float tensor containing the
mask loss for each instance.
"""
box_consistency_loss: A [num_instances] shaped float tensor containing
the box consistency loss for each instance.
num_instances = tf.shape(boxes)[0]
"""
if tf.keras.backend.learning_phase():
boxes = preprocessor.random_jitter_boxes(
boxes_for_crop = preprocessor.random_jitter_boxes(
boxes, self._deepmac_params.max_roi_jitter_ratio,
jitter_mode=self._deepmac_params.roi_jitter_mode)
else:
boxes_for_crop = boxes
mask_input = self._get_mask_head_input(
boxes, pixel_embedding)
boxes_for_crop, pixel_embedding)
instance_embeddings = self._get_instance_embeddings(
boxes, instance_embedding)
boxes_for_crop, instance_embedding)
mask_logits = self._mask_net(
instance_embeddings, mask_input,
training=tf.keras.backend.learning_phase())
mask_gt = self._get_groundtruth_mask_output(boxes, masks)
mask_logits = self._resize_logits_like_gt(mask_logits, mask_gt)
mask_gt = self._get_groundtruth_mask_output(boxes_for_crop, masks)
mask_logits = tf.reshape(mask_logits, [num_instances, -1, 1])
mask_gt = tf.reshape(mask_gt, [num_instances, -1, 1])
loss = self._deepmac_params.classification_loss(
prediction_tensor=mask_logits,
target_tensor=mask_gt,
weights=tf.ones_like(mask_logits))
mask_prediction_loss = self._compute_per_instance_mask_prediction_loss(
boxes_for_crop, mask_logits, mask_gt)
# TODO(vighneshb) Make this configurable via config.
if isinstance(self._deepmac_params.classification_loss,
losses.WeightedDiceClassificationLoss):
return tf.reduce_sum(loss, axis=1)
else:
return tf.reduce_mean(loss, axis=[1, 2])
box_consistency_loss = self._compute_per_instance_box_consistency_loss(
boxes, boxes_for_crop, mask_logits)
return mask_prediction_loss, box_consistency_loss
def _compute_instance_masks_loss(self, prediction_dict):
"""Computes the mask loss.
@@ -603,7 +731,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
[batch_size, height, width, embedding_size].
Returns:
loss: float, the mask loss as a scalar.
loss_dict: A dict mapping string (loss names) to scalar floats.
"""
gt_boxes_list = self.groundtruth_lists(fields.BoxListFields.boxes)
gt_weights_list = self.groundtruth_lists(fields.BoxListFields.weights)
@@ -613,7 +741,10 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
allowed_masked_classes_ids = (
self._deepmac_params.allowed_masked_classes_ids)
total_loss = 0.0
loss_dict = {
DEEP_MASK_ESTIMATION: 0.0,
DEEP_MASK_BOX_CONSISTENCY: 0.0
}
# Iterate over multiple predictions by backbone (for hourglass, length=2)
for instance_pred, pixel_pred in zip(
@@ -625,24 +756,31 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
for i, (boxes, weights, classes, masks) in enumerate(
zip(gt_boxes_list, gt_weights_list, gt_classes_list, gt_masks_list)):
_, weights, masks = filter_masked_classes(allowed_masked_classes_ids,
classes, weights, masks)
num_subsample = self._deepmac_params.mask_num_subsamples
_, weights, boxes, masks = subsample_instances(
classes, weights, boxes, masks, num_subsample)
# TODO(vighneshb) Add sub-sampling back if required.
classes, valid_mask_weights, masks = filter_masked_classes(
allowed_masked_classes_ids, classes, weights, masks)
per_instance_loss = self._compute_per_instance_mask_loss(
boxes, masks, instance_pred[i], pixel_pred[i])
per_instance_loss *= weights
per_instance_mask_loss, per_instance_consistency_loss = (
self._compute_per_instance_deepmac_losses(
boxes, masks, instance_pred[i], pixel_pred[i]))
per_instance_mask_loss *= valid_mask_weights
per_instance_consistency_loss *= weights
num_instances = tf.maximum(tf.reduce_sum(weights), 1.0)
num_instances_allowed = tf.maximum(
tf.reduce_sum(valid_mask_weights), 1.0)
loss_dict[DEEP_MASK_ESTIMATION] += (
tf.reduce_sum(per_instance_mask_loss) / num_instances_allowed)
total_loss += tf.reduce_sum(per_instance_loss) / num_instances
loss_dict[DEEP_MASK_BOX_CONSISTENCY] += (
tf.reduce_sum(per_instance_consistency_loss) / num_instances)
batch_size = len(gt_boxes_list)
num_predictions = len(prediction_dict[INSTANCE_EMBEDDING])
return total_loss / float(batch_size * num_predictions)
return dict((key, loss / float(batch_size * num_predictions))
for key, loss in loss_dict.items())
def loss(self, prediction_dict, true_image_shapes, scope=None):
@@ -650,13 +788,19 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
prediction_dict, true_image_shapes, scope)
if self._deepmac_params is not None:
mask_loss = self._compute_instance_masks_loss(
mask_loss_dict = self._compute_instance_masks_loss(
prediction_dict=prediction_dict)
key = LOSS_KEY_PREFIX + '/' + DEEP_MASK_ESTIMATION
losses_dict[key] = (
self._deepmac_params.task_loss_weight * mask_loss
losses_dict[LOSS_KEY_PREFIX + '/' + DEEP_MASK_ESTIMATION] = (
self._deepmac_params.task_loss_weight * mask_loss_dict[
DEEP_MASK_ESTIMATION]
)
if self._deepmac_params.box_consistency_loss_weight > 0.0:
losses_dict[LOSS_KEY_PREFIX + '/' + DEEP_MASK_BOX_CONSISTENCY] = (
self._deepmac_params.box_consistency_loss_weight * mask_loss_dict[
DEEP_MASK_BOX_CONSISTENCY]
)
return losses_dict
def postprocess(self, prediction_dict, true_image_shapes, **params):
......
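For orientation, a hedged usage sketch of how the new term surfaces to callers (assuming a DeepMACMetaArch built with box_consistency_loss_weight > 0 and a prediction_dict produced by the model's predict(), as exercised in the tests below; names outside this diff are illustrative):

losses_dict = model.loss(prediction_dict, true_image_shapes)
mask_key = (deepmac_meta_arch.LOSS_KEY_PREFIX + '/' +
            deepmac_meta_arch.DEEP_MASK_ESTIMATION)
consistency_key = (deepmac_meta_arch.LOSS_KEY_PREFIX + '/' +
                   deepmac_meta_arch.DEEP_MASK_BOX_CONSISTENCY)
# The consistency entry is only present when box_consistency_loss_weight > 0.
total_mask_loss = losses_dict[mask_key] + losses_dict.get(consistency_key, 0.0)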
@@ -60,7 +60,8 @@ class MockMaskNet(tf.keras.layers.Layer):
return tf.zeros_like(pixel_embedding[:, :, :, 0]) + 0.9
def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False):
def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
mask_num_subsamples=-1):
"""Builds the DeepMAC meta architecture."""
feature_extractor = DummyFeatureExtractor(
@@ -94,7 +95,7 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False):
pixel_embedding_dim=2,
allowed_masked_classes_ids=[],
mask_size=16,
mask_num_subsamples=-1,
mask_num_subsamples=mask_num_subsamples,
use_xy=True,
network_type='hourglass10',
use_instance_embedding=True,
@@ -102,7 +103,8 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False):
predict_full_resolution_masks=predict_full_resolution_masks,
postprocess_crop_size=128,
max_roi_jitter_ratio=0.0,
roi_jitter_mode='random'
roi_jitter_mode='random',
box_consistency_loss_weight=1.0,
)
object_detection_params = center_net_meta_arch.ObjectDetectionParams(
@@ -140,6 +142,33 @@ class DeepMACUtilsTest(tf.test.TestCase):
self.assertAllClose(result[2], boxes)
self.assertAllClose(result[3], masks)
def test_fill_boxes(self):
boxes = tf.constant([[0., 0., 0.5, 0.5], [0.5, 0.5, 1.0, 1.0]])
filled_boxes = deepmac_meta_arch.fill_boxes(boxes, 32, 32)
expected = np.zeros((2, 32, 32))
expected[0, :17, :17] = 1.0
expected[1, 16:, 16:] = 1.0
self.assertAllClose(expected, filled_boxes.numpy(), rtol=1e-3)
def test_crop_and_resize_instance_masks(self):
boxes = tf.zeros((5, 4))
masks = tf.zeros((5, 128, 128))
output = deepmac_meta_arch.crop_and_resize_instance_masks(
masks, boxes, 32)
self.assertEqual(output.shape, (5, 32, 32))
def test_crop_and_resize_feature_map(self):
boxes = tf.zeros((5, 4))
features = tf.zeros((128, 128, 7))
output = deepmac_meta_arch.crop_and_resize_feature_map(
features, boxes, 32)
self.assertEqual(output.shape, (5, 32, 32, 7))
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMetaArchTest(tf.test.TestCase):
@@ -199,7 +228,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
def test_get_mask_head_input_no_crop_resize(self):
model = build_meta_arch(predict_full_resolution_masks=True)
boxes = tf.constant([[0., 0., 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
boxes = tf.constant([[0., 0., 1.0, 1.0], [0.0, 0.0, 0.5, 1.0]],
dtype=tf.float32)
pixel_embedding_np = np.random.randn(32, 32, 4).astype(np.float32)
@@ -208,12 +237,15 @@ class DeepMACMetaArchTest(tf.test.TestCase):
mask_inputs = model._get_mask_head_input(boxes, pixel_embedding)
self.assertEqual(mask_inputs.shape, (2, 32, 32, 6))
y_grid, x_grid = tf.meshgrid(np.linspace(-1.0, 1.0, 32),
np.linspace(-1.0, 1.0, 32), indexing='ij')
y_grid, x_grid = tf.meshgrid(np.linspace(.0, 1.0, 32),
np.linspace(.0, 1.0, 32), indexing='ij')
ys = [0.5, 0.25]
xs = [0.5, 0.5]
for i in range(2):
mask_input = mask_inputs[i]
self.assertAllClose(y_grid, mask_input[:, :, 0])
self.assertAllClose(x_grid, mask_input[:, :, 1])
self.assertAllClose(y_grid - ys[i], mask_input[:, :, 0])
self.assertAllClose(x_grid - xs[i], mask_input[:, :, 1])
pixel_embedding = mask_input[:, :, 2:]
self.assertAllClose(pixel_embedding_np, pixel_embedding)
@@ -262,7 +294,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
masks[1, 16:, 16:] = 1.0
masks = tf.constant(masks)
loss = model._compute_per_instance_mask_loss(
loss, _ = model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)))
self.assertAllClose(
loss, np.zeros(2) - tf.math.log(tf.nn.sigmoid(0.9)))
@@ -275,7 +307,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
masks = np.ones((2, 128, 128), dtype=np.float32)
masks = tf.constant(masks)
loss = model._compute_per_instance_mask_loss(
loss, _ = model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)))
self.assertAllClose(
loss, np.zeros(2) - tf.math.log(tf.nn.sigmoid(0.9)))
@@ -289,7 +321,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
masks = np.ones((2, 128, 128), dtype=np.float32)
masks = tf.constant(masks)
loss = model._compute_per_instance_mask_loss(
loss, _ = model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)))
pred = tf.nn.sigmoid(0.9)
expected = (1.0 - ((2.0 * pred) / (1.0 + pred)))
@@ -299,7 +331,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
boxes = tf.zeros([0, 4])
masks = tf.zeros([0, 128, 128])
loss = self.model._compute_per_instance_mask_loss(
loss, _ = self.model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)))
self.assertEqual(loss.shape, (0,))
@@ -394,6 +426,59 @@ class DeepMACMetaArchTest(tf.test.TestCase):
out = call_func(tf.zeros((2, 4)), tf.zeros((2, 32, 32, 8)), training=True)
self.assertEqual(out.shape, (2, 32, 32))
def test_box_consistency_loss(self):
boxes_gt = tf.constant([[0., 0., 0.49, 1.0]])
boxes_jittered = tf.constant([[0.0, 0.0, 1.0, 1.0]])
mask_prediction = np.zeros((1, 32, 32)).astype(np.float32)
mask_prediction[0, :24, :24] = 1.0
loss = self.model._compute_per_instance_box_consistency_loss(
boxes_gt, boxes_jittered, tf.constant(mask_prediction))
yloss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.constant([1.0] * 8 + [0.0] * 8),
logits=[1.0] * 12 + [0.0] * 4)
xloss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.constant([1.0] * 16),
logits=[1.0] * 12 + [0.0] * 4)
self.assertAllClose(loss, [tf.reduce_mean(yloss + xloss).numpy()])
def test_box_consistency_dice_loss(self):
model = build_meta_arch(use_dice_loss=True)
boxes_gt = tf.constant([[0., 0., 0.49, 1.0]])
boxes_jittered = tf.constant([[0.0, 0.0, 1.0, 1.0]])
almost_inf = 1e10
mask_prediction = np.full((1, 32, 32), -almost_inf, dtype=np.float32)
mask_prediction[0, :24, :24] = almost_inf
loss = model._compute_per_instance_box_consistency_loss(
boxes_gt, boxes_jittered, tf.constant(mask_prediction))
yloss = 1 - 6.0 / 7
xloss = 0.2
self.assertAllClose(loss, [yloss + xloss])
def test_box_consistency_dice_loss_full_res(self):
model = build_meta_arch(use_dice_loss=True,
predict_full_resolution_masks=True)
boxes_gt = tf.constant([[0., 0., 1.0, 1.0]])
boxes_jittered = None
almost_inf = 1e10
mask_prediction = np.full((1, 32, 32), -almost_inf, dtype=np.float32)
mask_prediction[0, :16, :32] = almost_inf
loss = model._compute_per_instance_box_consistency_loss(
boxes_gt, boxes_jittered, tf.constant(mask_prediction))
self.assertAlmostEqual(loss[0].numpy(), 1 / 3)
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class FullyConnectedMaskHeadTest(tf.test.TestCase):
......
@@ -446,6 +446,10 @@ message CenterNet {
// The mode for jittering box ROIs. See RandomJitterBoxes in
// preprocessor.proto for more details.
optional RandomJitterBoxes.JitterMode jitter_mode = 15 [default=DEFAULT];
// Weight for the box consistency loss as described in the BoxInst paper
// https://arxiv.org/abs/2012.02310
optional float box_consistency_loss_weight = 16 [default=0.0];
}
optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
......
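A hedged pipeline config sketch for enabling the new term (assuming the usual model { center_net { ... } } nesting of Object Detection API pipeline configs; all other fields omitted):

model {
  center_net {
    # ... object center, detection and other task params ...
    deepmac_mask_estimation {
      # A non-zero weight enables the box consistency loss (default is 0.0).
      box_consistency_loss_weight: 1.0
    }
  }
}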