Commit b47ca971 authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Refactoring of Mask-RCNN to put all mask prediction code in third stage.

PiperOrigin-RevId: 192421843
parent 227f41e9
...@@ -451,8 +451,6 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -451,8 +451,6 @@ class FasterRCNNMetaArch(model.DetectionModel):
if self._number_of_stages <= 0 or self._number_of_stages > 3: if self._number_of_stages <= 0 or self._number_of_stages > 3:
raise ValueError('Number of stages should be a value in {1, 2, 3}.') raise ValueError('Number of stages should be a value in {1, 2, 3}.')
if self._is_training and self._number_of_stages == 3:
self._number_of_stages = 2
@property @property
def first_stage_feature_extractor_scope(self): def first_stage_feature_extractor_scope(self):
...@@ -739,9 +737,6 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -739,9 +737,6 @@ class FasterRCNNMetaArch(model.DetectionModel):
of the image. of the image.
6) box_classifier_features: a 4-D float32 tensor representing the 6) box_classifier_features: a 4-D float32 tensor representing the
features for each proposal. features for each proposal.
7) mask_predictions: (optional) a 4-D tensor with shape
[total_num_padded_proposals, num_classes, mask_height, mask_width]
containing instance mask predictions.
""" """
image_shape_2d = self._image_batch_shape_2d(image_shape) image_shape_2d = self._image_batch_shape_2d(image_shape)
proposal_boxes_normalized, _, num_proposals = self._postprocess_rpn( proposal_boxes_normalized, _, num_proposals = self._postprocess_rpn(
...@@ -757,15 +752,11 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -757,15 +752,11 @@ class FasterRCNNMetaArch(model.DetectionModel):
flattened_proposal_feature_maps, flattened_proposal_feature_maps,
scope=self.second_stage_feature_extractor_scope)) scope=self.second_stage_feature_extractor_scope))
predict_auxiliary_outputs = False
if self._number_of_stages == 2:
predict_auxiliary_outputs = True
box_predictions = self._mask_rcnn_box_predictor.predict( box_predictions = self._mask_rcnn_box_predictor.predict(
[box_classifier_features], [box_classifier_features],
num_predictions_per_location=[1], num_predictions_per_location=[1],
scope=self.second_stage_box_predictor_scope, scope=self.second_stage_box_predictor_scope,
predict_boxes_and_classes=True, predict_boxes_and_classes=True)
predict_auxiliary_outputs=predict_auxiliary_outputs)
refined_box_encodings = tf.squeeze( refined_box_encodings = tf.squeeze(
box_predictions[box_predictor.BOX_ENCODINGS], box_predictions[box_predictor.BOX_ENCODINGS],
...@@ -786,18 +777,16 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -786,18 +777,16 @@ class FasterRCNNMetaArch(model.DetectionModel):
'box_classifier_features': box_classifier_features, 'box_classifier_features': box_classifier_features,
'proposal_boxes_normalized': proposal_boxes_normalized, 'proposal_boxes_normalized': proposal_boxes_normalized,
} }
if box_predictor.MASK_PREDICTIONS in box_predictions:
mask_predictions = tf.squeeze(box_predictions[
box_predictor.MASK_PREDICTIONS], axis=1)
prediction_dict['mask_predictions'] = mask_predictions
return prediction_dict return prediction_dict
def _predict_third_stage(self, prediction_dict, image_shapes): def _predict_third_stage(self, prediction_dict, image_shapes):
"""Predicts non-box, non-class outputs using refined detections. """Predicts non-box, non-class outputs using refined detections.
This happens after calling the post-processing stage, such that masks For training, masks as predicted directly on the box_classifier_features,
are only calculated for the top scored boxes. which are region-features from the initial anchor boxes.
For inference, this happens after calling the post-processing stage, such
that masks are only calculated for the top scored boxes.
Args: Args:
prediction_dict: a dictionary holding "raw" prediction tensors: prediction_dict: a dictionary holding "raw" prediction tensors:
...@@ -819,16 +808,30 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -819,16 +808,30 @@ class FasterRCNNMetaArch(model.DetectionModel):
4) proposal_boxes: A float32 tensor of shape 4) proposal_boxes: A float32 tensor of shape
[batch_size, self.max_num_proposals, 4] representing [batch_size, self.max_num_proposals, 4] representing
decoded proposal bounding boxes in absolute coordinates. decoded proposal bounding boxes in absolute coordinates.
5) box_classifier_features: a 4-D float32 tensor representing the
features for each proposal.
image_shapes: A 2-D int32 tensors of shape [batch_size, 3] containing image_shapes: A 2-D int32 tensors of shape [batch_size, 3] containing
shapes of images in the batch. shapes of images in the batch.
Returns: Returns:
prediction_dict: a dictionary that in addition to the input predictions prediction_dict: a dictionary that in addition to the input predictions
does hold the following predictions as well: does hold the following predictions as well:
1) mask_predictions: (optional) a 4-D tensor with shape 1) mask_predictions: a 4-D tensor with shape
[batch_size, max_detection, mask_height, mask_width] containing [batch_size, max_detection, mask_height, mask_width] containing
instance mask predictions. instance mask predictions.
""" """
if self._is_training:
curr_box_classifier_features = prediction_dict['box_classifier_features']
detection_classes = prediction_dict['class_predictions_with_background']
box_predictions = self._mask_rcnn_box_predictor.predict(
[curr_box_classifier_features],
num_predictions_per_location=[1],
scope=self.second_stage_box_predictor_scope,
predict_boxes_and_classes=False,
predict_auxiliary_outputs=True)
prediction_dict['mask_predictions'] = tf.squeeze(box_predictions[
box_predictor.MASK_PREDICTIONS], axis=1)
else:
detections_dict = self._postprocess_box_classifier( detections_dict = self._postprocess_box_classifier(
prediction_dict['refined_box_encodings'], prediction_dict['refined_box_encodings'],
prediction_dict['class_predictions_with_background'], prediction_dict['class_predictions_with_background'],
...@@ -846,20 +849,21 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -846,20 +849,21 @@ class FasterRCNNMetaArch(model.DetectionModel):
flattened_detected_feature_maps = ( flattened_detected_feature_maps = (
self._compute_second_stage_input_feature_maps( self._compute_second_stage_input_feature_maps(
rpn_features_to_crop, detection_boxes)) rpn_features_to_crop, detection_boxes))
detected_box_classifier_features = ( curr_box_classifier_features = (
self._feature_extractor.extract_box_classifier_features( self._feature_extractor.extract_box_classifier_features(
flattened_detected_feature_maps, flattened_detected_feature_maps,
scope=self.second_stage_feature_extractor_scope)) scope=self.second_stage_feature_extractor_scope))
box_predictions = self._mask_rcnn_box_predictor.predict( box_predictions = self._mask_rcnn_box_predictor.predict(
[detected_box_classifier_features], [curr_box_classifier_features],
num_predictions_per_location=[1], num_predictions_per_location=[1],
scope=self.second_stage_box_predictor_scope, scope=self.second_stage_box_predictor_scope,
predict_boxes_and_classes=False, predict_boxes_and_classes=False,
predict_auxiliary_outputs=True) predict_auxiliary_outputs=True)
if box_predictor.MASK_PREDICTIONS in box_predictions:
detection_masks = tf.squeeze(box_predictions[ detection_masks = tf.squeeze(box_predictions[
box_predictor.MASK_PREDICTIONS], axis=1) box_predictor.MASK_PREDICTIONS], axis=1)
_, num_classes, mask_height, mask_width = ( _, num_classes, mask_height, mask_width = (
detection_masks.get_shape().as_list()) detection_masks.get_shape().as_list())
_, max_detection = detection_classes.get_shape().as_list() _, max_detection = detection_classes.get_shape().as_list()
......
...@@ -170,7 +170,7 @@ class FasterRCNNMetaArchTest( ...@@ -170,7 +170,7 @@ class FasterRCNNMetaArchTest(
with test_graph.as_default(): with test_graph.as_default():
model = self._build_model( model = self._build_model(
is_training=True, is_training=True,
number_of_stages=2, number_of_stages=3,
second_stage_batch_size=7, second_stage_batch_size=7,
predict_masks=True, predict_masks=True,
masks_are_class_agnostic=masks_are_class_agnostic) masks_are_class_agnostic=masks_are_class_agnostic)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment