Commit d64b4616 authored by Zhichao Lu's avatar Zhichao Lu Committed by lzc5123016
Browse files

Make sure boxes, scores, masks and classes have the same length.

PiperOrigin-RevId: 188061451
parent c21d3c25
...@@ -487,15 +487,12 @@ def result_dict_for_single_example(image, ...@@ -487,15 +487,12 @@ def result_dict_for_single_example(image,
detection_fields = fields.DetectionResultFields detection_fields = fields.DetectionResultFields
detection_boxes = detections[detection_fields.detection_boxes][0] detection_boxes = detections[detection_fields.detection_boxes][0]
output_dict[detection_fields.detection_boxes] = detection_boxes
image_shape = tf.shape(image) image_shape = tf.shape(image)
if scale_to_absolute: if scale_to_absolute:
absolute_detection_boxlist = box_list_ops.to_absolute_coordinates( absolute_detection_boxlist = box_list_ops.to_absolute_coordinates(
box_list.BoxList(detection_boxes), image_shape[1], image_shape[2]) box_list.BoxList(detection_boxes), image_shape[1], image_shape[2])
output_dict[detection_fields.detection_boxes] = ( detection_boxes = absolute_detection_boxlist.get()
absolute_detection_boxlist.get())
detection_scores = detections[detection_fields.detection_scores][0] detection_scores = detections[detection_fields.detection_scores][0]
output_dict[detection_fields.detection_scores] = detection_scores
if class_agnostic: if class_agnostic:
detection_classes = tf.ones_like(detection_scores, dtype=tf.int64) detection_classes = tf.ones_like(detection_scores, dtype=tf.int64)
...@@ -503,15 +500,22 @@ def result_dict_for_single_example(image, ...@@ -503,15 +500,22 @@ def result_dict_for_single_example(image,
detection_classes = ( detection_classes = (
tf.to_int64(detections[detection_fields.detection_classes][0]) + tf.to_int64(detections[detection_fields.detection_classes][0]) +
label_id_offset) label_id_offset)
num_detections = tf.to_int32(detections[detection_fields.num_detections][0])
detection_boxes = tf.slice(
detection_boxes, begin=[0, 0], size=[num_detections, -1])
detection_classes = tf.slice(
detection_classes, begin=[0], size=[num_detections])
detection_scores = tf.slice(
detection_scores, begin=[0], size=[num_detections])
output_dict[detection_fields.detection_boxes] = detection_boxes
output_dict[detection_fields.detection_classes] = detection_classes output_dict[detection_fields.detection_classes] = detection_classes
output_dict[detection_fields.detection_scores] = detection_scores
if detection_fields.detection_masks in detections: if detection_fields.detection_masks in detections:
detection_masks = detections[detection_fields.detection_masks][0] detection_masks = detections[detection_fields.detection_masks][0]
# TODO(rathodv): This should be done in model's postprocess # TODO(rathodv): This should be done in model's postprocess
# function ideally. # function ideally.
num_detections = tf.to_int32(detections[detection_fields.num_detections][0])
detection_boxes = tf.slice(
detection_boxes, begin=[0, 0], size=[num_detections, -1])
detection_masks = tf.slice( detection_masks = tf.slice(
detection_masks, begin=[0, 0, 0], size=[num_detections, -1, -1]) detection_masks, begin=[0, 0, 0], size=[num_detections, -1, -1])
detection_masks_reframed = ops.reframe_box_masks_to_image_masks( detection_masks_reframed = ops.reframe_box_masks_to_image_masks(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment