Commit a4d9c3a0 authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Class agnostic masks for mask_rcnn

PiperOrigin-RevId: 192132440
parent bfd15ec1
......@@ -111,6 +111,8 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
mask_rcnn_box_predictor.mask_prediction_num_conv_layers),
mask_prediction_conv_depth=(
mask_rcnn_box_predictor.mask_prediction_conv_depth),
masks_are_class_agnostic=(
mask_rcnn_box_predictor.masks_are_class_agnostic),
predict_keypoints=mask_rcnn_box_predictor.predict_keypoints)
return box_predictor_object
......
......@@ -307,6 +307,7 @@ class MaskRCNNBoxPredictor(BoxPredictor):
mask_width=14,
mask_prediction_num_conv_layers=2,
mask_prediction_conv_depth=256,
masks_are_class_agnostic=False,
predict_keypoints=False):
"""Constructor.
......@@ -337,6 +338,8 @@ class MaskRCNNBoxPredictor(BoxPredictor):
to 0, the depth of the convolution layers will be automatically chosen
based on the number of object classes and the number of channels in the
image features.
masks_are_class_agnostic: Boolean determining if the mask-head is
class-agnostic or not.
predict_keypoints: Whether to predict keypoints insde detection boxes.
......@@ -357,6 +360,7 @@ class MaskRCNNBoxPredictor(BoxPredictor):
self._mask_width = mask_width
self._mask_prediction_num_conv_layers = mask_prediction_num_conv_layers
self._mask_prediction_conv_depth = mask_prediction_conv_depth
self._masks_are_class_agnostic = masks_are_class_agnostic
self._predict_keypoints = predict_keypoints
if self._predict_keypoints:
raise ValueError('Keypoint prediction is unimplemented.')
......@@ -473,8 +477,9 @@ class MaskRCNNBoxPredictor(BoxPredictor):
upsampled_features,
num_outputs=num_conv_channels,
kernel_size=[3, 3])
num_masks = 1 if self._masks_are_class_agnostic else self.num_classes
mask_predictions = slim.conv2d(upsampled_features,
num_outputs=self.num_classes,
num_outputs=num_masks,
activation_fn=None,
kernel_size=[3, 3])
return tf.expand_dims(
......
......@@ -768,9 +768,11 @@ class FasterRCNNMetaArch(model.DetectionModel):
predict_auxiliary_outputs=predict_auxiliary_outputs)
refined_box_encodings = tf.squeeze(
box_predictions[box_predictor.BOX_ENCODINGS], axis=1)
class_predictions_with_background = tf.squeeze(box_predictions[
box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND], axis=1)
box_predictions[box_predictor.BOX_ENCODINGS],
axis=1, name='all_refined_box_encodings')
class_predictions_with_background = tf.squeeze(
box_predictions[box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND],
axis=1, name='all_class_predictions_with_background')
absolute_proposal_boxes = ops.normalized_to_image_coordinates(
proposal_boxes_normalized, image_shape, self._parallel_iterations)
......@@ -794,6 +796,9 @@ class FasterRCNNMetaArch(model.DetectionModel):
def _predict_third_stage(self, prediction_dict, image_shapes):
"""Predicts non-box, non-class outputs using refined detections.
This happens after calling the post-processing stage, such that masks
are only calculated for the top scored boxes.
Args:
prediction_dict: a dictionary holding "raw" prediction tensors:
1) refined_box_encodings: a 3-D tensor with shape
......@@ -851,16 +856,21 @@ class FasterRCNNMetaArch(model.DetectionModel):
scope=self.second_stage_box_predictor_scope,
predict_boxes_and_classes=False,
predict_auxiliary_outputs=True)
if box_predictor.MASK_PREDICTIONS in box_predictions:
detection_masks = tf.squeeze(box_predictions[
box_predictor.MASK_PREDICTIONS], axis=1)
detection_masks = self._gather_instance_masks(detection_masks,
detection_classes)
mask_height = tf.shape(detection_masks)[1]
mask_width = tf.shape(detection_masks)[2]
_, num_classes, mask_height, mask_width = (
detection_masks.get_shape().as_list())
_, max_detection = detection_classes.get_shape().as_list()
if num_classes > 1:
detection_masks = self._gather_instance_masks(
detection_masks, detection_classes)
prediction_dict[fields.DetectionResultFields.detection_masks] = (
tf.reshape(detection_masks,
[batch_size, max_detection, mask_height, mask_width]))
return prediction_dict
def _gather_instance_masks(self, instance_masks, classes):
......@@ -874,16 +884,12 @@ class FasterRCNNMetaArch(model.DetectionModel):
Returns:
masks: a 3-D float32 tensor with shape [K, mask_height, mask_width].
"""
_, num_classes, height, width = instance_masks.get_shape().as_list()
k = tf.shape(instance_masks)[0]
num_mask_classes = tf.shape(instance_masks)[1]
instance_mask_height = tf.shape(instance_masks)[2]
instance_mask_width = tf.shape(instance_masks)[3]
classes = tf.reshape(classes, [-1])
instance_masks = tf.reshape(instance_masks, [
-1, instance_mask_height, instance_mask_width
])
return tf.gather(instance_masks,
tf.range(k) * num_mask_classes + tf.to_int32(classes))
instance_masks = tf.reshape(instance_masks, [-1, height, width])
classes = tf.to_int32(tf.reshape(classes, [-1]))
gather_idx = tf.range(k) * num_classes + classes
return tf.gather(instance_masks, gather_idx)
def _extract_rpn_feature_maps(self, preprocessed_inputs):
"""Extracts RPN features.
......@@ -1815,11 +1821,18 @@ class FasterRCNNMetaArch(model.DetectionModel):
# Pad the prediction_masks with to add zeros for background class to be
# consistent with class predictions.
prediction_masks_with_background = tf.pad(
prediction_masks, [[0, 0], [1, 0], [0, 0], [0, 0]])
prediction_masks_masked_by_class_targets = tf.boolean_mask(
prediction_masks_with_background,
tf.greater(one_hot_flat_cls_targets_with_background, 0))
if prediction_masks.get_shape().as_list()[1] == 1:
# Class agnostic masks or masks for one-class prediction. Logic for
# both cases is the same since background predictions are ignored
# through the batch_mask_target_weights.
prediction_masks_masked_by_class_targets = prediction_masks
else:
prediction_masks_with_background = tf.pad(
prediction_masks, [[0, 0], [1, 0], [0, 0], [0, 0]])
prediction_masks_masked_by_class_targets = tf.boolean_mask(
prediction_masks_with_background,
tf.greater(one_hot_flat_cls_targets_with_background, 0))
mask_height = prediction_masks.shape[2].value
mask_width = prediction_masks.shape[3].value
reshaped_prediction_masks = tf.reshape(
......
......@@ -15,6 +15,7 @@
"""Tests for object_detection.meta_architectures.faster_rcnn_meta_arch."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
......@@ -22,7 +23,8 @@ from object_detection.meta_architectures import faster_rcnn_meta_arch_test_lib
class FasterRCNNMetaArchTest(
faster_rcnn_meta_arch_test_lib.FasterRCNNMetaArchTestBase):
faster_rcnn_meta_arch_test_lib.FasterRCNNMetaArchTestBase,
parameterized.TestCase):
def test_postprocess_second_stage_only_inference_mode_with_masks(self):
model = self._build_model(
......@@ -83,8 +85,12 @@ class FasterRCNNMetaArchTest(
self.assertTrue(np.amax(detections_out['detection_masks'] <= 1.0))
self.assertTrue(np.amin(detections_out['detection_masks'] >= 0.0))
@parameterized.parameters(
{'masks_are_class_agnostic': False},
{'masks_are_class_agnostic': True},
)
def test_predict_correct_shapes_in_inference_mode_three_stages_with_masks(
self):
self, masks_are_class_agnostic):
batch_size = 2
image_size = 10
max_num_proposals = 8
......@@ -126,7 +132,8 @@ class FasterRCNNMetaArchTest(
is_training=False,
number_of_stages=3,
second_stage_batch_size=2,
predict_masks=True)
predict_masks=True,
masks_are_class_agnostic=masks_are_class_agnostic)
preprocessed_inputs = tf.placeholder(tf.float32, shape=input_shape)
_, true_image_shapes = model.preprocess(preprocessed_inputs)
result_tensor_dict = model.predict(preprocessed_inputs,
......@@ -153,16 +160,20 @@ class FasterRCNNMetaArchTest(
self.assertAllEqual(tensor_dict_out['detection_scores'].shape, [2, 5])
self.assertAllEqual(tensor_dict_out['num_detections'].shape, [2])
@parameterized.parameters(
{'masks_are_class_agnostic': False},
{'masks_are_class_agnostic': True},
)
def test_predict_gives_correct_shapes_in_train_mode_both_stages_with_masks(
self):
self, masks_are_class_agnostic):
test_graph = tf.Graph()
with test_graph.as_default():
model = self._build_model(
is_training=True,
number_of_stages=2,
second_stage_batch_size=7,
predict_masks=True)
predict_masks=True,
masks_are_class_agnostic=masks_are_class_agnostic)
batch_size = 2
image_size = 10
max_num_proposals = 7
......@@ -184,6 +195,7 @@ class FasterRCNNMetaArchTest(
groundtruth_classes_list)
result_tensor_dict = model.predict(preprocessed_inputs, true_image_shapes)
mask_shape_1 = 1 if masks_are_class_agnostic else model._num_classes
expected_shapes = {
'rpn_box_predictor_features': (2, image_size, image_size, 512),
'rpn_features_to_crop': (2, image_size, image_size, 3),
......@@ -197,7 +209,7 @@ class FasterRCNNMetaArchTest(
self._get_box_classifier_features_shape(
image_size, batch_size, max_num_proposals, initial_crop_size,
maxpool_stride, 3),
'mask_predictions': (2 * max_num_proposals, 2, 14, 14)
'mask_predictions': (2 * max_num_proposals, mask_shape_1, 14, 14)
}
init_op = tf.global_variables_initializer()
......
......@@ -90,10 +90,13 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
"""
return box_predictor_text_proto
def _add_mask_to_second_stage_box_predictor_text_proto(self):
def _add_mask_to_second_stage_box_predictor_text_proto(
self, masks_are_class_agnostic=False):
agnostic = 'true' if masks_are_class_agnostic else 'false'
box_predictor_text_proto = """
mask_rcnn_box_predictor {
predict_instance_masks: true
masks_are_class_agnostic: """ + agnostic + """
mask_height: 14
mask_width: 14
conv_hyperparams {
......@@ -114,13 +117,14 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
return box_predictor_text_proto
def _get_second_stage_box_predictor(self, num_classes, is_training,
predict_masks):
predict_masks, masks_are_class_agnostic):
box_predictor_proto = box_predictor_pb2.BoxPredictor()
text_format.Merge(self._get_second_stage_box_predictor_text_proto(),
box_predictor_proto)
if predict_masks:
text_format.Merge(
self._add_mask_to_second_stage_box_predictor_text_proto(),
self._add_mask_to_second_stage_box_predictor_text_proto(
masks_are_class_agnostic),
box_predictor_proto)
return box_predictor_builder.build(
......@@ -146,7 +150,8 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
hard_mining=False,
softmax_second_stage_classification_loss=True,
predict_masks=False,
pad_to_max_dimension=None):
pad_to_max_dimension=None,
masks_are_class_agnostic=False):
def image_resizer_fn(image, masks=None):
"""Fake image resizer function."""
......@@ -287,7 +292,8 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
self._get_second_stage_box_predictor(
num_classes=num_classes,
is_training=is_training,
predict_masks=predict_masks), **common_kwargs)
predict_masks=predict_masks,
masks_are_class_agnostic=masks_are_class_agnostic), **common_kwargs)
def test_predict_gives_correct_shapes_in_inference_mode_first_stage_only(
self):
......
......@@ -118,6 +118,7 @@ message MaskRCNNBoxPredictor {
// The number of convolutions applied to image_features in the mask prediction
// branch.
optional int32 mask_prediction_num_conv_layers = 11 [default = 2];
optional bool masks_are_class_agnostic = 12 [default = false];
}
message RfcnBoxPredictor {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment