"examples/pytorch/vscode:/vscode.git/clone" did not exist on "d090ae865c76902849dfbf1286307fb8a2710cc4"
Commit d7f255f6 authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Provide groundtruth masks and keypoints in the evaluator.

PiperOrigin-RevId: 189929413
parent e46021a7
...@@ -94,14 +94,24 @@ def _extract_predictions_and_losses(model, ...@@ -94,14 +94,24 @@ def _extract_predictions_and_losses(model,
if fields.InputDataFields.groundtruth_group_of in input_dict: if fields.InputDataFields.groundtruth_group_of in input_dict:
groundtruth[fields.InputDataFields.groundtruth_group_of] = ( groundtruth[fields.InputDataFields.groundtruth_group_of] = (
input_dict[fields.InputDataFields.groundtruth_group_of]) input_dict[fields.InputDataFields.groundtruth_group_of])
groundtruth_masks_list = None
if fields.DetectionResultFields.detection_masks in detections: if fields.DetectionResultFields.detection_masks in detections:
groundtruth[fields.InputDataFields.groundtruth_instance_masks] = ( groundtruth[fields.InputDataFields.groundtruth_instance_masks] = (
input_dict[fields.InputDataFields.groundtruth_instance_masks]) input_dict[fields.InputDataFields.groundtruth_instance_masks])
groundtruth_masks_list = [
input_dict[fields.InputDataFields.groundtruth_instance_masks]]
groundtruth_keypoints_list = None
if fields.DetectionResultFields.detection_keypoints in detections:
groundtruth[fields.InputDataFields.groundtruth_keypoints] = (
input_dict[fields.InputDataFields.groundtruth_keypoints])
groundtruth_keypoints_list = [
input_dict[fields.InputDataFields.groundtruth_keypoints]]
label_id_offset = 1 label_id_offset = 1
model.provide_groundtruth( model.provide_groundtruth(
[input_dict[fields.InputDataFields.groundtruth_boxes]], [input_dict[fields.InputDataFields.groundtruth_boxes]],
[tf.one_hot(input_dict[fields.InputDataFields.groundtruth_classes] [tf.one_hot(input_dict[fields.InputDataFields.groundtruth_classes]
- label_id_offset, depth=model.num_classes)]) - label_id_offset, depth=model.num_classes)],
groundtruth_masks_list, groundtruth_keypoints_list)
losses_dict.update(model.loss(prediction_dict, true_image_shapes)) losses_dict.update(model.loss(prediction_dict, true_image_shapes))
result_dict = eval_util.result_dict_for_single_example( result_dict = eval_util.result_dict_for_single_example(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment