Commit a4d4d0f4 authored by syiming

pass image shape into _compute_second_stage_input_feature_maps

parent 47b575bb
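In short, the commit stops caching the resized input shape on the instance (self._resize_shape) and instead threads the image shape explicitly: callers read image_shape from prediction_dict and pass it down to _compute_second_stage_input_feature_maps along with the RPN features and boxes. A minimal sketch of that calling pattern (the wrapper function below is illustrative only, not part of the codebase):

def detection_box_feature_inputs(prediction_dict, detection_boxes,
                                 compute_feature_maps_fn):
  """Hypothetical helper mirroring the new call sites in the diff below."""
  rpn_features_to_crop = prediction_dict['rpn_features_to_crop']
  image_shape = prediction_dict['image_shape']  # 1-D tensor of shape [4]
  # The image shape now travels with the features and boxes instead of
  # being read from a value cached during preprocess().
  return compute_feature_maps_fn(
      rpn_features_to_crop, detection_boxes, image_shape)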
@@ -453,7 +453,6 @@ class FasterRCNNMetaArch(model.DetectionModel):
     self._is_training = is_training
     self._image_resizer_fn = image_resizer_fn
-    self._resize_shape = None
     self._resize_masks = resize_masks
     self._feature_extractor = feature_extractor
     if isinstance(feature_extractor, FasterRCNNKerasFeatureExtractor):
@@ -688,8 +687,6 @@ class FasterRCNNMetaArch(model.DetectionModel):
        true_image_shapes) = shape_utils.resize_images_and_return_shapes(
            inputs, self._image_resizer_fn)
-      self._resize_shape = resized_inputs.shape.as_list()
       return (self._feature_extractor.preprocess(resized_inputs),
               true_image_shapes)
@@ -1060,7 +1057,8 @@ class FasterRCNNMetaArch(model.DetectionModel):
     """
     flattened_proposal_feature_maps = (
         self._compute_second_stage_input_feature_maps(
-            rpn_features_to_crop, proposal_boxes_normalized, **side_inputs))
+            rpn_features_to_crop, proposal_boxes_normalized,
+            image_shape, **side_inputs))
     box_classifier_features = self._extract_box_classifier_features(
         flattened_proposal_feature_maps)
@@ -1192,6 +1190,8 @@ class FasterRCNNMetaArch(model.DetectionModel):
           decoded proposal bounding boxes in absolute coordinates.
         5) box_classifier_features: a 4-D float32 tensor representing the
           features for each proposal.
+        6) image_shape: a 1-D tensor of shape [4] representing the input
+          image shape.
       image_shapes: A 2-D int32 tensors of shape [batch_size, 3] containing
         shapes of images in the batch.
@@ -1230,11 +1230,12 @@ class FasterRCNNMetaArch(model.DetectionModel):
       detection_classes = detections_dict[
           fields.DetectionResultFields.detection_classes]
       rpn_features_to_crop = prediction_dict['rpn_features_to_crop']
+      image_shape = prediction_dict['image_shape']
       batch_size = tf.shape(detection_boxes)[0]
       max_detection = tf.shape(detection_boxes)[1]
       flattened_detected_feature_maps = (
           self._compute_second_stage_input_feature_maps(
-              rpn_features_to_crop, detection_boxes))
+              rpn_features_to_crop, detection_boxes, image_shape))
       curr_box_classifier_features = self._extract_box_classifier_features(
           flattened_detected_feature_maps)
@@ -1549,7 +1550,8 @@ class FasterRCNNMetaArch(model.DetectionModel):
       detections_dict[
           'detection_features'] = self._add_detection_features_output_node(
               detections_dict[fields.DetectionResultFields.detection_boxes],
-              prediction_dict['rpn_features_to_crop'])
+              prediction_dict['rpn_features_to_crop'],
+              prediction_dict['image_shape'])
     return detections_dict
@@ -1566,7 +1568,7 @@ class FasterRCNNMetaArch(model.DetectionModel):
     return prediction_dict

   def _add_detection_features_output_node(self, detection_boxes,
-                                          rpn_features_to_crop):
+                                          rpn_features_to_crop, image_shape):
     """Add detection features to outputs.

     This function extracts box features for each box in rpn_features_to_crop.
@@ -1581,6 +1583,7 @@ class FasterRCNNMetaArch(model.DetectionModel):
       rpn_features_to_crop: A list of 4-D float32 tensor with shape
         [batch, height, width, depth] representing image features to crop using
         the proposals boxes.
+      image_shape: a 1-D tensor of shape [4] representing the image shape.

     Returns:
       detection_features: a 4-D float32 tensor of shape
@@ -1590,7 +1593,7 @@ class FasterRCNNMetaArch(model.DetectionModel):
     with tf.name_scope('SecondStageDetectionFeaturesExtract'):
       flattened_detected_feature_maps = (
           self._compute_second_stage_input_feature_maps(
-              rpn_features_to_crop, detection_boxes))
+              rpn_features_to_crop, detection_boxes, image_shape))
       detection_features_unpooled = self._extract_box_classifier_features(
           flattened_detected_feature_maps)
@@ -1932,6 +1935,7 @@ class FasterRCNNMetaArch(model.DetectionModel):
   def _compute_second_stage_input_feature_maps(self, features_to_crop,
                                                proposal_boxes_normalized,
+                                               image_shape,
                                                **side_inputs):
     """Crops to a set of proposals from the feature map for a batch of images.
@@ -1945,6 +1949,7 @@ class FasterRCNNMetaArch(model.DetectionModel):
       proposal_boxes_normalized: A float32 tensor with shape [batch_size,
         num_proposals, box_code_size] containing proposal boxes in
         normalized coordinates.
+      image_shape: A 1D int32 tensors of size [4] containing the image shape.
       **side_inputs: additional tensors that are required by the network.

     Returns:
@@ -1957,9 +1962,11 @@ class FasterRCNNMetaArch(model.DetectionModel):
       # unit_scale_index: num_levels-2 is chosen based on section 4.2 of
       # https://arxiv.org/pdf/1612.03144.pdf and works best for Resnet based
      # feature extractor.
+      image_shape = image_shape.as_list()
       box_levels = ops.fpn_feature_levels(
           num_levels, num_levels - 2,
-          tf.sqrt(self._resize_shape[1] * self._resize_shape[2] * 1.0) / 224.0,
+          tf.sqrt(image_shape[1] * image_shape[2] * 1.0) / 224.0,
           proposal_boxes_normalized)
       cropped_regions = self._flatten_first_two_dimensions(
...
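For reference, the ratio passed to ops.fpn_feature_levels in the last hunk is sqrt(height * width) / 224, so each proposal gets assigned a feature level of roughly k0 + floor(log2(sqrt(box area) / 224)), the rule from section 4.2 of the FPN paper, with k0 = num_levels - 2. A simplified, self-contained sketch of that level assignment (an illustration of the computation, not the library implementation; image_shape may be either the 1-D shape tensor or a Python shape list, as in the diff):

import tensorflow as tf

def assign_fpn_levels(proposal_boxes_normalized, image_shape, num_levels):
  # image_shape is [batch, height, width, channels]; indices 1 and 2 are the
  # spatial dimensions of the preprocessed image.
  height = tf.cast(image_shape[1], tf.float32)
  width = tf.cast(image_shape[2], tf.float32)
  # Same scale factor the diff now derives from the passed-in image_shape.
  image_ratio = tf.sqrt(height * width) / 224.0
  # sqrt of each normalized box area; multiplied by image_ratio this gives
  # sqrt(absolute box area) / 224 for aspect-preserving resizing.
  box_hw = (proposal_boxes_normalized[:, :, 2:4] -
            proposal_boxes_normalized[:, :, 0:2])
  areas_sqrt = tf.sqrt(tf.reduce_prod(box_hw, axis=2))
  # Level k = k0 + floor(log2(sqrt(area) / 224)), clipped to valid levels,
  # with k0 = num_levels - 2 as noted in the comment in the diff.
  unit_scale_index = num_levels - 2
  levels = unit_scale_index + tf.cast(
      tf.floor(tf.math.log(areas_sqrt * image_ratio) / tf.math.log(2.0)),
      tf.int32)
  return tf.clip_by_value(levels, 0, num_levels - 1)

# Example: two boxes on a 640x640 input and a 5-level pyramid land on
# levels [3, 1].
levels = assign_fpn_levels(
    tf.constant([[[0.0, 0.0, 0.5, 0.5], [0.0, 0.0, 0.1, 0.1]]]),
    tf.constant([1, 640, 640, 3]), num_levels=5)

With the image shape passed in explicitly, this computation no longer depends on state left behind by preprocess().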