Commit 4f4810f0 authored by Geewook Kim

feat: add functions to calculate f1 accuracy score

parent 95cde5a9
@@ -6,6 +6,7 @@ MIT License
import json
import os
import random
from collections import defaultdict
from typing import Any, Dict, List, Tuple, Union
import torch
@@ -31,7 +32,7 @@ class DonutDataset(Dataset):
"""
DonutDataset which is saved in huggingface datasets format. (see details in https://huggingface.co/docs/datasets)
Each row, consists of image path(png/jpg/jpeg) and gt data (json/jsonl/txt),
and it will be converted into input_tensor(vectorized image) and input_ids(tokenized string).
and it will be converted into input_tensor(vectorized image) and input_ids(tokenized string)
Args:
dataset_name_or_path: name of dataset (available at huggingface.co/datasets) or the path containing image files and metadata.jsonl
@@ -94,7 +95,7 @@
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Load image from image_path of given dataset_path and convert into input_tensor and labels
        Load image from image_path of given dataset_path and convert into input_tensor and labels.
        Convert gt data into input_ids (tokenized string)

        Returns:
@@ -136,18 +137,50 @@ class DonutDataset(Dataset):
class JSONParseEvaluator:
    """
    Calculate n-TED(Normalized Tree Edit Distance) based accuracy between a predicted json and a gold json,
    calculated as,
        accuracy = 1 - TED(normalize(pred), normalize(gold)) / TED({}, normalize(gold))
    Calculate n-TED(Normalized Tree Edit Distance) based accuracy and F1 accuracy score
    """
    @staticmethod
    def flatten(data: dict):
        """
        Convert Dictionary into Non-nested Dictionary
        Example:
            input(dict)
                {
                    "menu": [
                        {"name" : ["cake"], "count" : ["2"]},
                        {"name" : ["juice"], "count" : ["1"]},
                    ]
                }
            output(dict)
                {
                    "menu.name": ["cake", "juice"],
                    "menu.count": ["2", "1"],
                }
        """
        flatten_data = defaultdict(list)

        def _flatten(value, key=""):
            if type(value) is dict:
                for child_key, child_value in value.items():
                    _flatten(child_value, f"{key}.{child_key}" if key else child_key)
            elif type(value) is list:
                for value_item in value:
                    _flatten(value_item, key)
            else:
                flatten_data[key].append(value)

        _flatten(data)
        return dict(flatten_data)
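
A quick usage sketch of flatten, reusing the docstring example above (flatten is a staticmethod, so it can be called on the class directly):

nested = {
    "menu": [
        {"name": ["cake"], "count": ["2"]},
        {"name": ["juice"], "count": ["1"]},
    ]
}
assert JSONParseEvaluator.flatten(nested) == {
    "menu.name": ["cake", "juice"],
    "menu.count": ["2", "1"],
}
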
    @staticmethod
    def update_cost(label1: str, label2: str):
        """
        Update cost for tree edit distance.
        If both are leaf nodes, calculate the string edit distance between the two labels (the special token '<leaf>' will be ignored).
        If one of them is a leaf node, cost is the length of the string in the leaf node + 1.
        If neither is a leaf node, cost is 0 if label1 is the same as label2, otherwise 1.
        If neither is a leaf node, cost is 0 if label1 is the same as label2, otherwise 1
        """
        label1_leaf = "<leaf>" in label1
        label2_leaf = "<leaf>" in label2
@@ -161,7 +194,7 @@
        return int(label1 != label2)
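
The hunk above collapses the branches between the leaf checks and the final return. A minimal sketch of a cost function implementing the docstring's three rules, assuming nltk's edit_distance as the string edit distance (an assumption; any Levenshtein implementation would do):

from nltk import edit_distance  # assumed string edit distance implementation

def update_cost_sketch(label1: str, label2: str) -> int:
    label1_leaf = "<leaf>" in label1
    label2_leaf = "<leaf>" in label2
    if label1_leaf and label2_leaf:
        # both leaves: string edit distance with the '<leaf>' marker stripped
        return edit_distance(label1.replace("<leaf>", ""), label2.replace("<leaf>", ""))
    elif label1_leaf:
        # exactly one leaf: length of the leaf's string + 1
        return 1 + len(label1.replace("<leaf>", ""))
    elif label2_leaf:
        return 1 + len(label2.replace("<leaf>", ""))
    else:
        # two inner nodes: 0 if the labels match, otherwise 1
        return int(label1 != label2)
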
    @staticmethod
    def insert_and_remove_cost(node):
    def insert_and_remove_cost(node: Node):
        """
        Insert and remove cost for tree edit distance.
        If leaf node, cost is length of label name.
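
The body of insert_and_remove_cost is collapsed below the docstring. A sketch consistent with the rule stated so far, assuming zss-style nodes that carry a .label attribute and (a further assumption) a unit cost for non-leaf nodes:

def insert_and_remove_cost_sketch(node) -> int:
    label = node.label  # assumes a zss-style Node exposing its label
    if "<leaf>" in label:
        # leaf node: cost is the length of the label, marker stripped
        return len(label.replace("<leaf>", ""))
    # non-leaf node: assumed constant cost
    return 1
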
@@ -175,7 +208,7 @@
    def normalize_dict(self, data: Union[Dict, List, Any]):
        """
        Sort by value, iterating over elements if data is a list.
        Sort by value, iterating over elements if data is a list
        """
        if not data:
            return {}
@@ -203,6 +236,22 @@
        return new_data

    def cal_f1(self, preds: List[dict], answers: List[dict]):
        """
        Calculate global F1 accuracy score (field-level, micro-averaged) by counting all true positives, false negatives, and false positives
        """
        total_tp, total_fn_or_fp = 0, 0
        for pred, answer in zip(preds, answers):
            pred, answer = self.flatten(self.normalize_dict(pred)), self.flatten(self.normalize_dict(answer))
            for pred_key, pred_values in pred.items():
                for pred_value in pred_values:
                    if pred_key in answer and pred_value in answer[pred_key]:
                        answer[pred_key].remove(pred_value)
                        total_tp += 1
                    else:
                        total_fn_or_fp += 1
            # values still left in `answer` were never matched by any prediction:
            # count them as false negatives so the docstring's claim holds
            total_fn_or_fp += sum(len(values) for values in answer.values())
        return total_tp / (total_tp + total_fn_or_fp / 2)
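
A worked example of the returned score, assuming normalize_dict passes these already-normalized values through unchanged (its body is collapsed above). One matched value (TP), one spurious prediction (FP), and one missed gold value (FN) give TP / (TP + (FP + FN) / 2) = 1 / (1 + 1) = 0.5, identical to the usual micro-F1 = 2TP / (2TP + FP + FN):

evaluator = JSONParseEvaluator()
preds = [{"menu": [{"name": ["cake"], "count": ["2"]}]}]
answers = [{"menu": [{"name": ["cake"], "count": ["3"]}]}]
# "cake" matches (total_tp = 1); the predicted count "2" is a false positive
# and the unmatched gold count "3" is a false negative (total_fn_or_fp = 2)
assert evaluator.cal_f1(preds, answers) == 0.5
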
    def construct_tree_from_dict(self, data: Union[Dict, List], node_name: str = None):
        """
        Convert Dictionary into Tree
@@ -252,7 +301,7 @@ class JSONParseEvaluator:
            raise Exception(data, node_name)
        return node

    def cal_acc(self, pred, answer):
    def cal_acc(self, pred: dict, answer: dict):
        """
        Calculate normalized tree edit distance(nTED) based accuracy.
        1) Construct tree from dict,
......
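
The remainder of cal_acc is collapsed above. A hedged sketch of the nTED-based accuracy it describes, following the formula in the original class docstring, accuracy = 1 - TED(pred, gold) / TED({}, gold), and assuming the zss tree-edit-distance package (the Node type hint above suggests zss.Node; the actual implementation may differ):

import zss  # assumed dependency

def cal_acc_sketch(evaluator: JSONParseEvaluator, pred: dict, answer: dict) -> float:
    def ted(a, b):
        # tree edit distance using this class's custom cost functions
        return zss.distance(
            a,
            b,
            get_children=zss.Node.get_children,
            insert_cost=evaluator.insert_and_remove_cost,
            remove_cost=evaluator.insert_and_remove_cost,
            update_cost=evaluator.update_cost,
        )

    pred_tree = evaluator.construct_tree_from_dict(evaluator.normalize_dict(pred))
    answer_tree = evaluator.construct_tree_from_dict(evaluator.normalize_dict(answer))
    empty_tree = evaluator.construct_tree_from_dict(evaluator.normalize_dict({}))
    # accuracy = 1 - TED(pred, gold) / TED({}, gold), clipped below at 0
    return max(0.0, 1.0 - ted(pred_tree, answer_tree) / ted(empty_tree, answer_tree))
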
@@ -32,9 +32,11 @@ def test(args):
    if args.save_path:
        os.makedirs(os.path.dirname(args.save_path), exist_ok=True)

    output_list = []
    predictions = []
    ground_truths = []
    accs = []

    evaluator = JSONParseEvaluator()
    dataset = load_dataset(args.dataset_name_or_path, split=args.split)

    for idx, sample in tqdm(enumerate(dataset), total=len(dataset)):
@@ -52,24 +54,35 @@ def test(args):
            gt = ground_truth["gt_parse"]
            score = float(output["class"] == gt["class"])
        elif args.task_name == "docvqa":
            score = 0.0  # note: docvqa is evaluated on the official website
            # Note: we evaluated the model on the official website.
            # In this script, an exact-match based score will be returned instead
            gt = ground_truth["gt_parses"]
            answers = set([qa_parse["answer"] for qa_parse in gt])
            score = float(output["answer"] in answers)
        else:
            gt = ground_truth["gt_parse"]
            evaluator = JSONParseEvaluator()
            score = evaluator.cal_acc(output, gt)

        accs.append(score)

        output_list.append(output)
        predictions.append(output)
        ground_truths.append(gt)

    scores = {"accuracies": accs, "mean_accuracy": np.mean(accs)}
    print(scores, f"length : {len(accs)}")
    scores = {
        "ted_accuracies": accs,
        "ted_accuracy": np.mean(accs),
        "f1_accuracy": evaluator.cal_f1(predictions, ground_truths),
    }
    print(
        f"Total number of samples: {len(accs)}, Tree Edit Distance (TED) based accuracy score: {scores['ted_accuracy']}, F1 accuracy score: {scores['f1_accuracy']}"
    )

    if args.save_path:
        scores["predictions"] = output_list
        scores["predictions"] = predictions
        scores["ground_truths"] = ground_truths
        save_json(args.save_path, scores)

    return output_list
    return predictions
if __name__ == "__main__":
@@ -84,4 +97,4 @@ if __name__ == "__main__":
    if args.task_name is None:
        args.task_name = os.path.basename(args.dataset_name_or_path)

    predicts = test(args)
    predictions = test(args)
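
A hypothetical programmatic call mirroring the __main__ block above. Only the attributes visible in this excerpt are filled in; the collapsed parts of test() read further fields (e.g., the model checkpoint argument), so treat this as a sketch rather than a complete invocation:

from types import SimpleNamespace

args = SimpleNamespace(
    dataset_name_or_path="naver-clova-ix/cord-v2",  # assumed example dataset
    split="test",
    task_name=None,  # falls back to the dataset basename, as above
    save_path="./result/output.json",
)
predictions = test(args)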