Commit 4f4810f0 authored by Geewook Kim

feat: add functions to calculate f1 accuracy score

parent 95cde5a9
@@ -6,6 +6,7 @@ MIT License
import json
import os
import random
from collections import defaultdict
from typing import Any, Dict, List, Tuple, Union

import torch
@@ -31,7 +32,7 @@ class DonutDataset(Dataset):
    """
    DonutDataset which is saved in huggingface datasets format. (see details in https://huggingface.co/docs/datasets)
    Each row consists of an image path (png/jpg/jpeg) and gt data (json/jsonl/txt),
    and it will be converted into input_tensor (vectorized image) and input_ids (tokenized string)

    Args:
        dataset_name_or_path: name of dataset (available at huggingface.co/datasets) or the path containing image files and metadata.jsonl
@@ -94,7 +95,7 @@ class DonutDataset(Dataset):
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Load image from image_path of given dataset_path and convert into input_tensor and labels.
        Convert gt data into input_ids (tokenized string)

        Returns:
@@ -136,18 +137,50 @@ class DonutDataset(Dataset):

class JSONParseEvaluator:
    """
    Calculate n-TED (Normalized Tree Edit Distance) based accuracy and F1 accuracy score
    """
    @staticmethod
    def flatten(data: dict):
        """
        Convert a nested dictionary into a non-nested dictionary
        Example:
            input(dict)
                {
                    "menu": [
                        {"name" : ["cake"], "count" : ["2"]},
                        {"name" : ["juice"], "count" : ["1"]},
                    ]
                }
            output(dict)
                {
                    "menu.name": ["cake", "juice"],
                    "menu.count": ["2", "1"],
                }
        """
        flatten_data = defaultdict(list)

        def _flatten(value, key=""):
            if type(value) is dict:
                for child_key, child_value in value.items():
                    _flatten(child_value, f"{key}.{child_key}" if key else child_key)
            elif type(value) is list:
                for value_item in value:
                    _flatten(value_item, key)
            else:
                flatten_data[key].append(value)

        _flatten(data)
        return dict(flatten_data)
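A minimal illustration of how `flatten` behaves on the docstring's example (hypothetical usage, not part of this commit):

```python
nested = {
    "menu": [
        {"name": ["cake"], "count": ["2"]},
        {"name": ["juice"], "count": ["1"]},
    ]
}
# flatten is a staticmethod, so it can be called on the class directly
flat = JSONParseEvaluator.flatten(nested)
assert flat == {"menu.name": ["cake", "juice"], "menu.count": ["2", "1"]}
```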
    @staticmethod
    def update_cost(label1: str, label2: str):
        """
        Update cost for tree edit distance.
        If both are leaf nodes, calculate the string edit distance between the two labels (the special token '<leaf>' is ignored).
        If only one of them is a leaf node, the cost is the length of the string in the leaf node + 1.
        If neither is a leaf node, the cost is 0 if label1 is the same as label2, otherwise 1
        """
        label1_leaf = "<leaf>" in label1
        label2_leaf = "<leaf>" in label2
@@ -161,7 +194,7 @@ class JSONParseEvaluator:
        return int(label1 != label2)
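The docstring's rules imply concrete costs such as the following (expected values inferred from the docstring alone, since the branch bodies are collapsed in this diff):

```python
JSONParseEvaluator.update_cost("<leaf>cake", "<leaf>care")  # both leaves: edit distance "cake" -> "care" = 1
JSONParseEvaluator.update_cost("menu", "<leaf>cake")        # one leaf: len("cake") + 1 = 5
JSONParseEvaluator.update_cost("menu", "menu")              # neither leaf, same label: 0
JSONParseEvaluator.update_cost("menu", "total")             # neither leaf, different labels: 1
```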
    @staticmethod
    def insert_and_remove_cost(node: Node):
        """
        Insert and remove cost for tree edit distance.
        If leaf node, cost is length of label name.
@@ -175,7 +208,7 @@ class JSONParseEvaluator:
    def normalize_dict(self, data: Union[Dict, List, Any]):
        """
        Sort by value, iterating over elements if data is a list
        """
        if not data:
            return {}
@@ -203,6 +236,22 @@ class JSONParseEvaluator:
        return new_data
    def cal_f1(self, preds: List[dict], answers: List[dict]):
        """
        Calculate global F1 accuracy score (field-level, micro-averaged) by counting all true positives, false negatives and false positives
        """
        total_tp, total_fn_or_fp = 0, 0
        for pred, answer in zip(preds, answers):
            pred, answer = self.flatten(self.normalize_dict(pred)), self.flatten(self.normalize_dict(answer))
            for pred_key, pred_values in pred.items():
                for pred_value in pred_values:
                    if pred_key in answer and pred_value in answer[pred_key]:
                        answer[pred_key].remove(pred_value)
                        total_tp += 1
                    else:
                        total_fn_or_fp += 1
        return total_tp / (total_tp + total_fn_or_fp / 2)
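The return expression is micro-averaged F1 written in one step: F1 = 2 * TP / (2 * TP + FP + FN), which equals TP / (TP + (FP + FN) / 2) after dividing through by 2. Each matched field value counts as a true positive and is consumed from the answer, so duplicated predictions cannot match the same ground-truth value twice. A minimal sketch of a call (hypothetical inputs, not part of this commit; it assumes these simple values survive normalize_dict unchanged, which its collapsed body above does not show):

```python
evaluator = JSONParseEvaluator()
preds = [{"menu": [{"name": ["cake"], "count": ["2"]}]}]
answers = [{"menu": [{"name": ["cake"], "count": ["3"]}]}]
# "cake" matches under "menu.name" (1 TP); "2" finds no match under "menu.count" (1 FP)
print(evaluator.cal_f1(preds, answers))  # 1 / (1 + 1 / 2) = 0.666...
```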
    def construct_tree_from_dict(self, data: Union[Dict, List], node_name: str = None):
        """
        Convert Dictionary into Tree
@@ -252,7 +301,7 @@ class JSONParseEvaluator:
            raise Exception(data, node_name)
        return node
    def cal_acc(self, pred: dict, answer: dict):
        """
        Calculate normalized tree edit distance (nTED) based accuracy.
        1) Construct tree from dict,
        ...
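For reference, the class docstring shortened by this commit spelled the metric out as `accuracy = 1 - TED(normalize(pred), normalize(gold)) / TED({}, normalize(gold))`, i.e. the tree edit distance to the gold tree, normalized by the distance from an empty tree. A hypothetical call (not part of this commit):

```python
evaluator = JSONParseEvaluator()
acc = evaluator.cal_acc(
    pred={"menu": [{"name": ["cake"], "count": ["2"]}]},
    answer={"menu": [{"name": ["cake"], "count": ["3"]}]},
)  # a float in [0, 1]; 1.0 only when the normalized trees match exactly
```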
@@ -32,9 +32,11 @@ def test(args):
    if args.save_path:
        os.makedirs(os.path.dirname(args.save_path), exist_ok=True)

    predictions = []
    ground_truths = []
    accs = []

    evaluator = JSONParseEvaluator()
    dataset = load_dataset(args.dataset_name_or_path, split=args.split)

    for idx, sample in tqdm(enumerate(dataset), total=len(dataset)):
@@ -52,24 +54,35 @@ def test(args):
            gt = ground_truth["gt_parse"]
            score = float(output["class"] == gt["class"])
        elif args.task_name == "docvqa":
            # Note: we evaluated the model on the official website.
            # In this script, an exact-match based score will be returned instead
            gt = ground_truth["gt_parses"]
            answers = set([qa_parse["answer"] for qa_parse in gt])
            score = float(output["answer"] in answers)
        else:
            gt = ground_truth["gt_parse"]
            score = evaluator.cal_acc(output, gt)

        accs.append(score)
        predictions.append(output)
        ground_truths.append(gt)
    scores = {
        "ted_accuracies": accs,
        "ted_accuracy": np.mean(accs),
        "f1_accuracy": evaluator.cal_f1(predictions, ground_truths),
    }
    print(
        f"Total number of samples: {len(accs)}, Tree Edit Distance (TED) based accuracy score: {scores['ted_accuracy']}, F1 accuracy score: {scores['f1_accuracy']}"
    )
    if args.save_path:
        scores["predictions"] = predictions
        scores["ground_truths"] = ground_truths
        save_json(args.save_path, scores)

    return predictions
if __name__ == "__main__":
@@ -84,4 +97,4 @@ if __name__ == "__main__":
    if args.task_name is None:
        args.task_name = os.path.basename(args.dataset_name_or_path)

    predictions = test(args)
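With these changes, running the script prints both metrics and, when `--save_path` is set, saves predictions and ground truths alongside the scores. An illustrative invocation using only the flags visible in this diff (the argument parser itself is collapsed here, and the dataset name is just an example):

```
python test.py --dataset_name_or_path naver-clova-ix/cord-v2 --split test --save_path ./result/cord_v2.json
```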