Commit d64b4616 authored by Zhichao Lu's avatar Zhichao Lu Committed by lzc5123016
Browse files

Make sure boxes, scores, masks and classes have the same length.

PiperOrigin-RevId: 188061451
parent c21d3c25
......@@ -487,15 +487,12 @@ def result_dict_for_single_example(image,
detection_fields = fields.DetectionResultFields
detection_boxes = detections[detection_fields.detection_boxes][0]
output_dict[detection_fields.detection_boxes] = detection_boxes
image_shape = tf.shape(image)
if scale_to_absolute:
absolute_detection_boxlist = box_list_ops.to_absolute_coordinates(
box_list.BoxList(detection_boxes), image_shape[1], image_shape[2])
output_dict[detection_fields.detection_boxes] = (
absolute_detection_boxlist.get())
detection_boxes = absolute_detection_boxlist.get()
detection_scores = detections[detection_fields.detection_scores][0]
output_dict[detection_fields.detection_scores] = detection_scores
if class_agnostic:
detection_classes = tf.ones_like(detection_scores, dtype=tf.int64)
......@@ -503,15 +500,22 @@ def result_dict_for_single_example(image,
detection_classes = (
tf.to_int64(detections[detection_fields.detection_classes][0]) +
label_id_offset)
num_detections = tf.to_int32(detections[detection_fields.num_detections][0])
detection_boxes = tf.slice(
detection_boxes, begin=[0, 0], size=[num_detections, -1])
detection_classes = tf.slice(
detection_classes, begin=[0], size=[num_detections])
detection_scores = tf.slice(
detection_scores, begin=[0], size=[num_detections])
output_dict[detection_fields.detection_boxes] = detection_boxes
output_dict[detection_fields.detection_classes] = detection_classes
output_dict[detection_fields.detection_scores] = detection_scores
if detection_fields.detection_masks in detections:
detection_masks = detections[detection_fields.detection_masks][0]
# TODO(rathodv): This should be done in model's postprocess
# function ideally.
num_detections = tf.to_int32(detections[detection_fields.num_detections][0])
detection_boxes = tf.slice(
detection_boxes, begin=[0, 0], size=[num_detections, -1])
detection_masks = tf.slice(
detection_masks, begin=[0, 0, 0], size=[num_detections, -1, -1])
detection_masks_reframed = ops.reframe_box_masks_to_image_masks(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment