Commit d7f255f6 authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Provide groundtruth masks and keypoints in the evaluator.

PiperOrigin-RevId: 189929413
parent e46021a7
......@@ -94,14 +94,24 @@ def _extract_predictions_and_losses(model,
if fields.InputDataFields.groundtruth_group_of in input_dict:
groundtruth[fields.InputDataFields.groundtruth_group_of] = (
input_dict[fields.InputDataFields.groundtruth_group_of])
groundtruth_masks_list = None
if fields.DetectionResultFields.detection_masks in detections:
groundtruth[fields.InputDataFields.groundtruth_instance_masks] = (
input_dict[fields.InputDataFields.groundtruth_instance_masks])
groundtruth_masks_list = [
input_dict[fields.InputDataFields.groundtruth_instance_masks]]
groundtruth_keypoints_list = None
if fields.DetectionResultFields.detection_keypoints in detections:
groundtruth[fields.InputDataFields.groundtruth_keypoints] = (
input_dict[fields.InputDataFields.groundtruth_keypoints])
groundtruth_keypoints_list = [
input_dict[fields.InputDataFields.groundtruth_keypoints]]
label_id_offset = 1
model.provide_groundtruth(
[input_dict[fields.InputDataFields.groundtruth_boxes]],
[tf.one_hot(input_dict[fields.InputDataFields.groundtruth_classes]
- label_id_offset, depth=model.num_classes)])
- label_id_offset, depth=model.num_classes)],
groundtruth_masks_list, groundtruth_keypoints_list)
losses_dict.update(model.loss(prediction_dict, true_image_shapes))
result_dict = eval_util.result_dict_for_single_example(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment