"examples/vscode:/vscode.git/clone" did not exist on "e58eeebfb68e42871f60cc0702b35d1110653d80"
Commit d64b4616 authored by Zhichao Lu's avatar Zhichao Lu Committed by lzc5123016
Browse files

Make sure boxes, scores, masks and classes have the same length.

PiperOrigin-RevId: 188061451
parent c21d3c25
......@@ -487,15 +487,12 @@ def result_dict_for_single_example(image,
detection_fields = fields.DetectionResultFields
detection_boxes = detections[detection_fields.detection_boxes][0]
output_dict[detection_fields.detection_boxes] = detection_boxes
image_shape = tf.shape(image)
if scale_to_absolute:
absolute_detection_boxlist = box_list_ops.to_absolute_coordinates(
box_list.BoxList(detection_boxes), image_shape[1], image_shape[2])
output_dict[detection_fields.detection_boxes] = (
absolute_detection_boxlist.get())
detection_boxes = absolute_detection_boxlist.get()
detection_scores = detections[detection_fields.detection_scores][0]
output_dict[detection_fields.detection_scores] = detection_scores
if class_agnostic:
detection_classes = tf.ones_like(detection_scores, dtype=tf.int64)
......@@ -503,15 +500,22 @@ def result_dict_for_single_example(image,
detection_classes = (
tf.to_int64(detections[detection_fields.detection_classes][0]) +
label_id_offset)
num_detections = tf.to_int32(detections[detection_fields.num_detections][0])
detection_boxes = tf.slice(
detection_boxes, begin=[0, 0], size=[num_detections, -1])
detection_classes = tf.slice(
detection_classes, begin=[0], size=[num_detections])
detection_scores = tf.slice(
detection_scores, begin=[0], size=[num_detections])
output_dict[detection_fields.detection_boxes] = detection_boxes
output_dict[detection_fields.detection_classes] = detection_classes
output_dict[detection_fields.detection_scores] = detection_scores
if detection_fields.detection_masks in detections:
detection_masks = detections[detection_fields.detection_masks][0]
# TODO(rathodv): This should be done in model's postprocess
# function ideally.
num_detections = tf.to_int32(detections[detection_fields.num_detections][0])
detection_boxes = tf.slice(
detection_boxes, begin=[0, 0], size=[num_detections, -1])
detection_masks = tf.slice(
detection_masks, begin=[0, 0, 0], size=[num_detections, -1, -1])
detection_masks_reframed = ops.reframe_box_masks_to_image_masks(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment