Commit a4d9c3a0 authored by Zhichao Lu, committed by pkulzc
Browse files

Class agnostic masks for mask_rcnn

PiperOrigin-RevId: 192132440
parent bfd15ec1
...@@ -111,6 +111,8 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes): ...@@ -111,6 +111,8 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
mask_rcnn_box_predictor.mask_prediction_num_conv_layers), mask_rcnn_box_predictor.mask_prediction_num_conv_layers),
mask_prediction_conv_depth=( mask_prediction_conv_depth=(
mask_rcnn_box_predictor.mask_prediction_conv_depth), mask_rcnn_box_predictor.mask_prediction_conv_depth),
masks_are_class_agnostic=(
mask_rcnn_box_predictor.masks_are_class_agnostic),
predict_keypoints=mask_rcnn_box_predictor.predict_keypoints) predict_keypoints=mask_rcnn_box_predictor.predict_keypoints)
return box_predictor_object return box_predictor_object
......
...@@ -307,6 +307,7 @@ class MaskRCNNBoxPredictor(BoxPredictor): ...@@ -307,6 +307,7 @@ class MaskRCNNBoxPredictor(BoxPredictor):
mask_width=14, mask_width=14,
mask_prediction_num_conv_layers=2, mask_prediction_num_conv_layers=2,
mask_prediction_conv_depth=256, mask_prediction_conv_depth=256,
masks_are_class_agnostic=False,
predict_keypoints=False): predict_keypoints=False):
"""Constructor. """Constructor.
...@@ -337,6 +338,8 @@ class MaskRCNNBoxPredictor(BoxPredictor): ...@@ -337,6 +338,8 @@ class MaskRCNNBoxPredictor(BoxPredictor):
to 0, the depth of the convolution layers will be automatically chosen to 0, the depth of the convolution layers will be automatically chosen
based on the number of object classes and the number of channels in the based on the number of object classes and the number of channels in the
image features. image features.
masks_are_class_agnostic: Boolean determining if the mask-head is
class-agnostic or not.
predict_keypoints: Whether to predict keypoints inside detection boxes. predict_keypoints: Whether to predict keypoints inside detection boxes.
...@@ -357,6 +360,7 @@ class MaskRCNNBoxPredictor(BoxPredictor): ...@@ -357,6 +360,7 @@ class MaskRCNNBoxPredictor(BoxPredictor):
self._mask_width = mask_width self._mask_width = mask_width
self._mask_prediction_num_conv_layers = mask_prediction_num_conv_layers self._mask_prediction_num_conv_layers = mask_prediction_num_conv_layers
self._mask_prediction_conv_depth = mask_prediction_conv_depth self._mask_prediction_conv_depth = mask_prediction_conv_depth
self._masks_are_class_agnostic = masks_are_class_agnostic
self._predict_keypoints = predict_keypoints self._predict_keypoints = predict_keypoints
if self._predict_keypoints: if self._predict_keypoints:
raise ValueError('Keypoint prediction is unimplemented.') raise ValueError('Keypoint prediction is unimplemented.')
...@@ -473,8 +477,9 @@ class MaskRCNNBoxPredictor(BoxPredictor): ...@@ -473,8 +477,9 @@ class MaskRCNNBoxPredictor(BoxPredictor):
upsampled_features, upsampled_features,
num_outputs=num_conv_channels, num_outputs=num_conv_channels,
kernel_size=[3, 3]) kernel_size=[3, 3])
num_masks = 1 if self._masks_are_class_agnostic else self.num_classes
mask_predictions = slim.conv2d(upsampled_features, mask_predictions = slim.conv2d(upsampled_features,
num_outputs=self.num_classes, num_outputs=num_masks,
activation_fn=None, activation_fn=None,
kernel_size=[3, 3]) kernel_size=[3, 3])
return tf.expand_dims( return tf.expand_dims(
......
...@@ -768,9 +768,11 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -768,9 +768,11 @@ class FasterRCNNMetaArch(model.DetectionModel):
predict_auxiliary_outputs=predict_auxiliary_outputs) predict_auxiliary_outputs=predict_auxiliary_outputs)
refined_box_encodings = tf.squeeze( refined_box_encodings = tf.squeeze(
box_predictions[box_predictor.BOX_ENCODINGS], axis=1) box_predictions[box_predictor.BOX_ENCODINGS],
class_predictions_with_background = tf.squeeze(box_predictions[ axis=1, name='all_refined_box_encodings')
box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND], axis=1) class_predictions_with_background = tf.squeeze(
box_predictions[box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND],
axis=1, name='all_class_predictions_with_background')
absolute_proposal_boxes = ops.normalized_to_image_coordinates( absolute_proposal_boxes = ops.normalized_to_image_coordinates(
proposal_boxes_normalized, image_shape, self._parallel_iterations) proposal_boxes_normalized, image_shape, self._parallel_iterations)
...@@ -794,6 +796,9 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -794,6 +796,9 @@ class FasterRCNNMetaArch(model.DetectionModel):
def _predict_third_stage(self, prediction_dict, image_shapes): def _predict_third_stage(self, prediction_dict, image_shapes):
"""Predicts non-box, non-class outputs using refined detections. """Predicts non-box, non-class outputs using refined detections.
This happens after calling the post-processing stage, such that masks
are only calculated for the top scored boxes.
Args: Args:
prediction_dict: a dictionary holding "raw" prediction tensors: prediction_dict: a dictionary holding "raw" prediction tensors:
1) refined_box_encodings: a 3-D tensor with shape 1) refined_box_encodings: a 3-D tensor with shape
...@@ -851,16 +856,21 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -851,16 +856,21 @@ class FasterRCNNMetaArch(model.DetectionModel):
scope=self.second_stage_box_predictor_scope, scope=self.second_stage_box_predictor_scope,
predict_boxes_and_classes=False, predict_boxes_and_classes=False,
predict_auxiliary_outputs=True) predict_auxiliary_outputs=True)
if box_predictor.MASK_PREDICTIONS in box_predictions: if box_predictor.MASK_PREDICTIONS in box_predictions:
detection_masks = tf.squeeze(box_predictions[ detection_masks = tf.squeeze(box_predictions[
box_predictor.MASK_PREDICTIONS], axis=1) box_predictor.MASK_PREDICTIONS], axis=1)
detection_masks = self._gather_instance_masks(detection_masks, _, num_classes, mask_height, mask_width = (
detection_classes) detection_masks.get_shape().as_list())
mask_height = tf.shape(detection_masks)[1] _, max_detection = detection_classes.get_shape().as_list()
mask_width = tf.shape(detection_masks)[2] if num_classes > 1:
detection_masks = self._gather_instance_masks(
detection_masks, detection_classes)
prediction_dict[fields.DetectionResultFields.detection_masks] = ( prediction_dict[fields.DetectionResultFields.detection_masks] = (
tf.reshape(detection_masks, tf.reshape(detection_masks,
[batch_size, max_detection, mask_height, mask_width])) [batch_size, max_detection, mask_height, mask_width]))
return prediction_dict return prediction_dict
def _gather_instance_masks(self, instance_masks, classes): def _gather_instance_masks(self, instance_masks, classes):
...@@ -874,16 +884,12 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -874,16 +884,12 @@ class FasterRCNNMetaArch(model.DetectionModel):
Returns: Returns:
masks: a 3-D float32 tensor with shape [K, mask_height, mask_width]. masks: a 3-D float32 tensor with shape [K, mask_height, mask_width].
""" """
_, num_classes, height, width = instance_masks.get_shape().as_list()
k = tf.shape(instance_masks)[0] k = tf.shape(instance_masks)[0]
num_mask_classes = tf.shape(instance_masks)[1] instance_masks = tf.reshape(instance_masks, [-1, height, width])
instance_mask_height = tf.shape(instance_masks)[2] classes = tf.to_int32(tf.reshape(classes, [-1]))
instance_mask_width = tf.shape(instance_masks)[3] gather_idx = tf.range(k) * num_classes + classes
classes = tf.reshape(classes, [-1]) return tf.gather(instance_masks, gather_idx)
instance_masks = tf.reshape(instance_masks, [
-1, instance_mask_height, instance_mask_width
])
return tf.gather(instance_masks,
tf.range(k) * num_mask_classes + tf.to_int32(classes))
def _extract_rpn_feature_maps(self, preprocessed_inputs): def _extract_rpn_feature_maps(self, preprocessed_inputs):
"""Extracts RPN features. """Extracts RPN features.
...@@ -1815,11 +1821,18 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -1815,11 +1821,18 @@ class FasterRCNNMetaArch(model.DetectionModel):
# Pad the prediction_masks with zeros for the background class to be # Pad the prediction_masks with zeros for the background class to be
# consistent with class predictions. # consistent with class predictions.
prediction_masks_with_background = tf.pad( if prediction_masks.get_shape().as_list()[1] == 1:
prediction_masks, [[0, 0], [1, 0], [0, 0], [0, 0]]) # Class agnostic masks or masks for one-class prediction. Logic for
prediction_masks_masked_by_class_targets = tf.boolean_mask( # both cases is the same since background predictions are ignored
prediction_masks_with_background, # through the batch_mask_target_weights.
tf.greater(one_hot_flat_cls_targets_with_background, 0)) prediction_masks_masked_by_class_targets = prediction_masks
else:
prediction_masks_with_background = tf.pad(
prediction_masks, [[0, 0], [1, 0], [0, 0], [0, 0]])
prediction_masks_masked_by_class_targets = tf.boolean_mask(
prediction_masks_with_background,
tf.greater(one_hot_flat_cls_targets_with_background, 0))
mask_height = prediction_masks.shape[2].value mask_height = prediction_masks.shape[2].value
mask_width = prediction_masks.shape[3].value mask_width = prediction_masks.shape[3].value
reshaped_prediction_masks = tf.reshape( reshaped_prediction_masks = tf.reshape(
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
"""Tests for object_detection.meta_architectures.faster_rcnn_meta_arch.""" """Tests for object_detection.meta_architectures.faster_rcnn_meta_arch."""
from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -22,7 +23,8 @@ from object_detection.meta_architectures import faster_rcnn_meta_arch_test_lib ...@@ -22,7 +23,8 @@ from object_detection.meta_architectures import faster_rcnn_meta_arch_test_lib
class FasterRCNNMetaArchTest( class FasterRCNNMetaArchTest(
faster_rcnn_meta_arch_test_lib.FasterRCNNMetaArchTestBase): faster_rcnn_meta_arch_test_lib.FasterRCNNMetaArchTestBase,
parameterized.TestCase):
def test_postprocess_second_stage_only_inference_mode_with_masks(self): def test_postprocess_second_stage_only_inference_mode_with_masks(self):
model = self._build_model( model = self._build_model(
...@@ -83,8 +85,12 @@ class FasterRCNNMetaArchTest( ...@@ -83,8 +85,12 @@ class FasterRCNNMetaArchTest(
self.assertTrue(np.amax(detections_out['detection_masks'] <= 1.0)) self.assertTrue(np.amax(detections_out['detection_masks'] <= 1.0))
self.assertTrue(np.amin(detections_out['detection_masks'] >= 0.0)) self.assertTrue(np.amin(detections_out['detection_masks'] >= 0.0))
@parameterized.parameters(
{'masks_are_class_agnostic': False},
{'masks_are_class_agnostic': True},
)
def test_predict_correct_shapes_in_inference_mode_three_stages_with_masks( def test_predict_correct_shapes_in_inference_mode_three_stages_with_masks(
self): self, masks_are_class_agnostic):
batch_size = 2 batch_size = 2
image_size = 10 image_size = 10
max_num_proposals = 8 max_num_proposals = 8
...@@ -126,7 +132,8 @@ class FasterRCNNMetaArchTest( ...@@ -126,7 +132,8 @@ class FasterRCNNMetaArchTest(
is_training=False, is_training=False,
number_of_stages=3, number_of_stages=3,
second_stage_batch_size=2, second_stage_batch_size=2,
predict_masks=True) predict_masks=True,
masks_are_class_agnostic=masks_are_class_agnostic)
preprocessed_inputs = tf.placeholder(tf.float32, shape=input_shape) preprocessed_inputs = tf.placeholder(tf.float32, shape=input_shape)
_, true_image_shapes = model.preprocess(preprocessed_inputs) _, true_image_shapes = model.preprocess(preprocessed_inputs)
result_tensor_dict = model.predict(preprocessed_inputs, result_tensor_dict = model.predict(preprocessed_inputs,
...@@ -153,16 +160,20 @@ class FasterRCNNMetaArchTest( ...@@ -153,16 +160,20 @@ class FasterRCNNMetaArchTest(
self.assertAllEqual(tensor_dict_out['detection_scores'].shape, [2, 5]) self.assertAllEqual(tensor_dict_out['detection_scores'].shape, [2, 5])
self.assertAllEqual(tensor_dict_out['num_detections'].shape, [2]) self.assertAllEqual(tensor_dict_out['num_detections'].shape, [2])
@parameterized.parameters(
{'masks_are_class_agnostic': False},
{'masks_are_class_agnostic': True},
)
def test_predict_gives_correct_shapes_in_train_mode_both_stages_with_masks( def test_predict_gives_correct_shapes_in_train_mode_both_stages_with_masks(
self): self, masks_are_class_agnostic):
test_graph = tf.Graph() test_graph = tf.Graph()
with test_graph.as_default(): with test_graph.as_default():
model = self._build_model( model = self._build_model(
is_training=True, is_training=True,
number_of_stages=2, number_of_stages=2,
second_stage_batch_size=7, second_stage_batch_size=7,
predict_masks=True) predict_masks=True,
masks_are_class_agnostic=masks_are_class_agnostic)
batch_size = 2 batch_size = 2
image_size = 10 image_size = 10
max_num_proposals = 7 max_num_proposals = 7
...@@ -184,6 +195,7 @@ class FasterRCNNMetaArchTest( ...@@ -184,6 +195,7 @@ class FasterRCNNMetaArchTest(
groundtruth_classes_list) groundtruth_classes_list)
result_tensor_dict = model.predict(preprocessed_inputs, true_image_shapes) result_tensor_dict = model.predict(preprocessed_inputs, true_image_shapes)
mask_shape_1 = 1 if masks_are_class_agnostic else model._num_classes
expected_shapes = { expected_shapes = {
'rpn_box_predictor_features': (2, image_size, image_size, 512), 'rpn_box_predictor_features': (2, image_size, image_size, 512),
'rpn_features_to_crop': (2, image_size, image_size, 3), 'rpn_features_to_crop': (2, image_size, image_size, 3),
...@@ -197,7 +209,7 @@ class FasterRCNNMetaArchTest( ...@@ -197,7 +209,7 @@ class FasterRCNNMetaArchTest(
self._get_box_classifier_features_shape( self._get_box_classifier_features_shape(
image_size, batch_size, max_num_proposals, initial_crop_size, image_size, batch_size, max_num_proposals, initial_crop_size,
maxpool_stride, 3), maxpool_stride, 3),
'mask_predictions': (2 * max_num_proposals, 2, 14, 14) 'mask_predictions': (2 * max_num_proposals, mask_shape_1, 14, 14)
} }
init_op = tf.global_variables_initializer() init_op = tf.global_variables_initializer()
......
...@@ -90,10 +90,13 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase): ...@@ -90,10 +90,13 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
""" """
return box_predictor_text_proto return box_predictor_text_proto
def _add_mask_to_second_stage_box_predictor_text_proto(self): def _add_mask_to_second_stage_box_predictor_text_proto(
self, masks_are_class_agnostic=False):
agnostic = 'true' if masks_are_class_agnostic else 'false'
box_predictor_text_proto = """ box_predictor_text_proto = """
mask_rcnn_box_predictor { mask_rcnn_box_predictor {
predict_instance_masks: true predict_instance_masks: true
masks_are_class_agnostic: """ + agnostic + """
mask_height: 14 mask_height: 14
mask_width: 14 mask_width: 14
conv_hyperparams { conv_hyperparams {
...@@ -114,13 +117,14 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase): ...@@ -114,13 +117,14 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
return box_predictor_text_proto return box_predictor_text_proto
def _get_second_stage_box_predictor(self, num_classes, is_training, def _get_second_stage_box_predictor(self, num_classes, is_training,
predict_masks): predict_masks, masks_are_class_agnostic):
box_predictor_proto = box_predictor_pb2.BoxPredictor() box_predictor_proto = box_predictor_pb2.BoxPredictor()
text_format.Merge(self._get_second_stage_box_predictor_text_proto(), text_format.Merge(self._get_second_stage_box_predictor_text_proto(),
box_predictor_proto) box_predictor_proto)
if predict_masks: if predict_masks:
text_format.Merge( text_format.Merge(
self._add_mask_to_second_stage_box_predictor_text_proto(), self._add_mask_to_second_stage_box_predictor_text_proto(
masks_are_class_agnostic),
box_predictor_proto) box_predictor_proto)
return box_predictor_builder.build( return box_predictor_builder.build(
...@@ -146,7 +150,8 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase): ...@@ -146,7 +150,8 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
hard_mining=False, hard_mining=False,
softmax_second_stage_classification_loss=True, softmax_second_stage_classification_loss=True,
predict_masks=False, predict_masks=False,
pad_to_max_dimension=None): pad_to_max_dimension=None,
masks_are_class_agnostic=False):
def image_resizer_fn(image, masks=None): def image_resizer_fn(image, masks=None):
"""Fake image resizer function.""" """Fake image resizer function."""
...@@ -287,7 +292,8 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase): ...@@ -287,7 +292,8 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
self._get_second_stage_box_predictor( self._get_second_stage_box_predictor(
num_classes=num_classes, num_classes=num_classes,
is_training=is_training, is_training=is_training,
predict_masks=predict_masks), **common_kwargs) predict_masks=predict_masks,
masks_are_class_agnostic=masks_are_class_agnostic), **common_kwargs)
def test_predict_gives_correct_shapes_in_inference_mode_first_stage_only( def test_predict_gives_correct_shapes_in_inference_mode_first_stage_only(
self): self):
......
...@@ -118,6 +118,7 @@ message MaskRCNNBoxPredictor { ...@@ -118,6 +118,7 @@ message MaskRCNNBoxPredictor {
// The number of convolutions applied to image_features in the mask prediction // The number of convolutions applied to image_features in the mask prediction
// branch. // branch.
optional int32 mask_prediction_num_conv_layers = 11 [default = 2]; optional int32 mask_prediction_num_conv_layers = 11 [default = 2];
optional bool masks_are_class_agnostic = 12 [default = false];
} }
message RfcnBoxPredictor { message RfcnBoxPredictor {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment