Commit b47ca971 authored by Zhichao Lu, committed by pkulzc

Refactoring of Mask-RCNN to put all mask prediction code in third stage.

PiperOrigin-RevId: 192421843
parent 227f41e9
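In effect, the training-time clamp that silently downgraded a three-stage (mask) model to two stages is removed, so the mask head now runs as a real third stage during training as well. The snippet below is a minimal, standalone sketch of that behavioral change; `_StageConfigSketch` and its `legacy_clamp` flag are hypothetical illustrations, not part of the library's API.

class _StageConfigSketch(object):
  """Hypothetical stand-in for the meta-architecture's stage-count handling."""

  def __init__(self, is_training, number_of_stages, legacy_clamp=False):
    if number_of_stages <= 0 or number_of_stages > 3:
      raise ValueError('Number of stages should be a value in {1, 2, 3}.')
    if legacy_clamp and is_training and number_of_stages == 3:
      # Old behavior (removed by this commit): a three-stage model was
      # silently downgraded to two stages while training.
      number_of_stages = 2
    self.number_of_stages = number_of_stages


print(_StageConfigSketch(is_training=True, number_of_stages=3,
                         legacy_clamp=True).number_of_stages)   # 2 (old behavior)
print(_StageConfigSketch(is_training=True, number_of_stages=3,
                         legacy_clamp=False).number_of_stages)  # 3 (new behavior)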
@@ -451,8 +451,6 @@ class FasterRCNNMetaArch(model.DetectionModel):
     if self._number_of_stages <= 0 or self._number_of_stages > 3:
       raise ValueError('Number of stages should be a value in {1, 2, 3}.')
-    if self._is_training and self._number_of_stages == 3:
-      self._number_of_stages = 2
 
   @property
   def first_stage_feature_extractor_scope(self):
@@ -739,9 +737,6 @@ class FasterRCNNMetaArch(model.DetectionModel):
           of the image.
         6) box_classifier_features: a 4-D float32 tensor representing the
           features for each proposal.
-        7) mask_predictions: (optional) a 4-D tensor with shape
-          [total_num_padded_proposals, num_classes, mask_height, mask_width]
-          containing instance mask predictions.
     """
     image_shape_2d = self._image_batch_shape_2d(image_shape)
     proposal_boxes_normalized, _, num_proposals = self._postprocess_rpn(
@@ -757,15 +752,11 @@ class FasterRCNNMetaArch(model.DetectionModel):
             flattened_proposal_feature_maps,
             scope=self.second_stage_feature_extractor_scope))
 
-    predict_auxiliary_outputs = False
-    if self._number_of_stages == 2:
-      predict_auxiliary_outputs = True
     box_predictions = self._mask_rcnn_box_predictor.predict(
         [box_classifier_features],
         num_predictions_per_location=[1],
         scope=self.second_stage_box_predictor_scope,
-        predict_boxes_and_classes=True,
-        predict_auxiliary_outputs=predict_auxiliary_outputs)
+        predict_boxes_and_classes=True)
 
     refined_box_encodings = tf.squeeze(
         box_predictions[box_predictor.BOX_ENCODINGS],
@@ -786,18 +777,16 @@ class FasterRCNNMetaArch(model.DetectionModel):
         'box_classifier_features': box_classifier_features,
         'proposal_boxes_normalized': proposal_boxes_normalized,
     }
-    if box_predictor.MASK_PREDICTIONS in box_predictions:
-      mask_predictions = tf.squeeze(box_predictions[
-          box_predictor.MASK_PREDICTIONS], axis=1)
-      prediction_dict['mask_predictions'] = mask_predictions
     return prediction_dict
 
   def _predict_third_stage(self, prediction_dict, image_shapes):
     """Predicts non-box, non-class outputs using refined detections.
 
-    This happens after calling the post-processing stage, such that masks
-    are only calculated for the top scored boxes.
+    For training, masks are predicted directly on the box_classifier_features,
+    which are region-features from the initial anchor boxes.
+    For inference, this happens after calling the post-processing stage, such
+    that masks are only calculated for the top scored boxes.
 
     Args:
       prediction_dict: a dictionary holding "raw" prediction tensors:
@@ -819,47 +808,62 @@ class FasterRCNNMetaArch(model.DetectionModel):
         4) proposal_boxes: A float32 tensor of shape
           [batch_size, self.max_num_proposals, 4] representing
           decoded proposal bounding boxes in absolute coordinates.
+        5) box_classifier_features: a 4-D float32 tensor representing the
+          features for each proposal.
       image_shapes: A 2-D int32 tensors of shape [batch_size, 3] containing
         shapes of images in the batch.
 
     Returns:
       prediction_dict: a dictionary that in addition to the input predictions
         does hold the following predictions as well:
-        1) mask_predictions: (optional) a 4-D tensor with shape
+        1) mask_predictions: a 4-D tensor with shape
           [batch_size, max_detection, mask_height, mask_width] containing
           instance mask predictions.
     """
-    detections_dict = self._postprocess_box_classifier(
-        prediction_dict['refined_box_encodings'],
-        prediction_dict['class_predictions_with_background'],
-        prediction_dict['proposal_boxes'],
-        prediction_dict['num_proposals'],
-        image_shapes)
-    prediction_dict.update(detections_dict)
-    detection_boxes = detections_dict[
-        fields.DetectionResultFields.detection_boxes]
-    detection_classes = detections_dict[
-        fields.DetectionResultFields.detection_classes]
-    rpn_features_to_crop = prediction_dict['rpn_features_to_crop']
-    batch_size = tf.shape(detection_boxes)[0]
-    max_detection = tf.shape(detection_boxes)[1]
-    flattened_detected_feature_maps = (
-        self._compute_second_stage_input_feature_maps(
-            rpn_features_to_crop, detection_boxes))
-    detected_box_classifier_features = (
-        self._feature_extractor.extract_box_classifier_features(
-            flattened_detected_feature_maps,
-            scope=self.second_stage_feature_extractor_scope))
-    box_predictions = self._mask_rcnn_box_predictor.predict(
-        [detected_box_classifier_features],
-        num_predictions_per_location=[1],
-        scope=self.second_stage_box_predictor_scope,
-        predict_boxes_and_classes=False,
-        predict_auxiliary_outputs=True)
-
-    detection_masks = tf.squeeze(box_predictions[
-        box_predictor.MASK_PREDICTIONS], axis=1)
-    _, num_classes, mask_height, mask_width = (
-        detection_masks.get_shape().as_list())
-    _, max_detection = detection_classes.get_shape().as_list()
+    if self._is_training:
+      curr_box_classifier_features = prediction_dict['box_classifier_features']
+      detection_classes = prediction_dict['class_predictions_with_background']
+      box_predictions = self._mask_rcnn_box_predictor.predict(
+          [curr_box_classifier_features],
+          num_predictions_per_location=[1],
+          scope=self.second_stage_box_predictor_scope,
+          predict_boxes_and_classes=False,
+          predict_auxiliary_outputs=True)
+      prediction_dict['mask_predictions'] = tf.squeeze(box_predictions[
+          box_predictor.MASK_PREDICTIONS], axis=1)
+    else:
+      detections_dict = self._postprocess_box_classifier(
+          prediction_dict['refined_box_encodings'],
+          prediction_dict['class_predictions_with_background'],
+          prediction_dict['proposal_boxes'],
+          prediction_dict['num_proposals'],
+          image_shapes)
+      prediction_dict.update(detections_dict)
+      detection_boxes = detections_dict[
+          fields.DetectionResultFields.detection_boxes]
+      detection_classes = detections_dict[
+          fields.DetectionResultFields.detection_classes]
+      rpn_features_to_crop = prediction_dict['rpn_features_to_crop']
+      batch_size = tf.shape(detection_boxes)[0]
+      max_detection = tf.shape(detection_boxes)[1]
+      flattened_detected_feature_maps = (
+          self._compute_second_stage_input_feature_maps(
+              rpn_features_to_crop, detection_boxes))
+      curr_box_classifier_features = (
+          self._feature_extractor.extract_box_classifier_features(
+              flattened_detected_feature_maps,
+              scope=self.second_stage_feature_extractor_scope))
+      box_predictions = self._mask_rcnn_box_predictor.predict(
+          [curr_box_classifier_features],
+          num_predictions_per_location=[1],
+          scope=self.second_stage_box_predictor_scope,
+          predict_boxes_and_classes=False,
+          predict_auxiliary_outputs=True)
+      if box_predictor.MASK_PREDICTIONS in box_predictions:
+        detection_masks = tf.squeeze(box_predictions[
+            box_predictor.MASK_PREDICTIONS], axis=1)
+        _, num_classes, mask_height, mask_width = (
+            detection_masks.get_shape().as_list())
+        _, max_detection = detection_classes.get_shape().as_list()
...
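To summarize the refactored control flow above: during training the mask head now runs on the proposal features already cached by the second stage, while at inference boxes are post-processed first so masks are only computed for the top-scoring detections. The sketch below is a simplified, hypothetical rendering of that branch structure, not the library's code; `third_stage_sketch`, `mask_head`, `crop_and_extract` and `postprocess_boxes` are illustrative names, with numpy arrays standing in for tensors.

import numpy as np


def third_stage_sketch(prediction_dict, is_training, mask_head,
                       crop_and_extract, postprocess_boxes):
  """Adds 'mask_predictions' to prediction_dict, mirroring the new branching.

  mask_head, crop_and_extract and postprocess_boxes are hypothetical callables
  standing in for the Mask R-CNN box predictor, ROI feature extraction and the
  second-stage post-processing used by the real meta-architecture.
  """
  if is_training:
    # Training branch: reuse the proposal features computed by the second
    # stage; no NMS or score thresholding happens before the mask head.
    features = prediction_dict['box_classifier_features']
  else:
    # Inference branch: post-process boxes first, then crop features only for
    # the top-scoring detections, so the mask head sees far fewer regions.
    detections = postprocess_boxes(prediction_dict)
    prediction_dict.update(detections)
    features = crop_and_extract(prediction_dict['rpn_features_to_crop'],
                                detections['detection_boxes'])
  prediction_dict['mask_predictions'] = mask_head(features)
  return prediction_dict


# Toy usage with dummy arrays and callables:
dummy = {'box_classifier_features': np.zeros((8, 7, 7, 256)),
         'rpn_features_to_crop': np.zeros((1, 38, 50, 1024))}
out = third_stage_sketch(
    dummy, is_training=True,
    mask_head=lambda f: np.zeros((f.shape[0], 90, 33, 33)),
    crop_and_extract=lambda feats, boxes: feats,
    postprocess_boxes=lambda p: {'detection_boxes': np.zeros((1, 100, 4))})
print(out['mask_predictions'].shape)  # (8, 90, 33, 33)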
@@ -170,7 +170,7 @@ class FasterRCNNMetaArchTest(
     with test_graph.as_default():
       model = self._build_model(
           is_training=True,
-          number_of_stages=2,
+          number_of_stages=3,
           second_stage_batch_size=7,
           predict_masks=True,
           masks_are_class_agnostic=masks_are_class_agnostic)
...