"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "912fdff899cf0fd674ed357e46a0209311aefad2"
Commit 86bcafe1 authored by Geewook Kim
Browse files

feat: update JSONParseEvaluator

parent d2fd95a3
...@@ -137,14 +137,13 @@ class DonutDataset(Dataset): ...@@ -137,14 +137,13 @@ class DonutDataset(Dataset):
class JSONParseEvaluator: class JSONParseEvaluator:
""" """
Calculate n-TED(Normalized Tree Edit Distance) based accuracy and F1 accuracy score Calculate n-TED(Normalized Tree Edit Distance) based accuracy and F1 accuracy score
""" """
@staticmethod @staticmethod
def flatten(data: dict): def flatten(data: dict):
""" """
Convert Dictionary into Non-nested Dictionary Convert Dictionary into Non-nested Dictionary
Example: Example:
input(dict) input(dict)
{ {
...@@ -153,13 +152,15 @@ class JSONParseEvaluator: ...@@ -153,13 +152,15 @@ class JSONParseEvaluator:
{"name" : ["juice"], "count" : ["1"]}, {"name" : ["juice"], "count" : ["1"]},
] ]
} }
output(dict) output(list)
{ [
"menu.name": ["cake", "juice"], ("menu.name", "cake"),
"menu.count": ["2", "1"], ("menu.count", "2"),
} ("menu.name", "juice"),
("menu.count", "1"),
]
""" """
flatten_data = defaultdict(list) flatten_data = list()
def _flatten(value, key=""): def _flatten(value, key=""):
if type(value) is dict: if type(value) is dict:
...@@ -169,10 +170,10 @@ class JSONParseEvaluator: ...@@ -169,10 +170,10 @@ class JSONParseEvaluator:
for value_item in value: for value_item in value:
_flatten(value_item, key) _flatten(value_item, key)
else: else:
flatten_data[key].append(value) flatten_data.append((key, value))
_flatten(data) _flatten(data)
return dict(flatten_data) return flatten_data
@staticmethod @staticmethod
def update_cost(label1: str, label2: str): def update_cost(label1: str, label2: str):
...@@ -225,10 +226,11 @@ class JSONParseEvaluator: ...@@ -225,10 +226,11 @@ class JSONParseEvaluator:
elif isinstance(data, list): elif isinstance(data, list):
if all(isinstance(item, dict) for item in data): if all(isinstance(item, dict) for item in data):
new_data = [] new_data = []
for item in sorted(data, key=lambda x: str(sorted(x.items()))): for item in data:
item = self.normalize_dict(item) item = self.normalize_dict(item)
if item: if item:
new_data.append(item) new_data.append(item)
new_data = sorted(new_data, key=lambda x: str(x.keys())+str(x.values()))
else: else:
new_data = sorted([str(item) for item in data if type(item) in {str, int, float} and str(item)]) new_data = sorted([str(item) for item in data if type(item) in {str, int, float} and str(item)])
else: else:
...@@ -243,14 +245,14 @@ class JSONParseEvaluator: ...@@ -243,14 +245,14 @@ class JSONParseEvaluator:
total_tp, total_fn_or_fp = 0, 0 total_tp, total_fn_or_fp = 0, 0
for pred, answer in zip(preds, answers): for pred, answer in zip(preds, answers):
pred, answer = self.flatten(self.normalize_dict(pred)), self.flatten(self.normalize_dict(answer)) pred, answer = self.flatten(self.normalize_dict(pred)), self.flatten(self.normalize_dict(answer))
for pred_key, pred_values in pred.items(): for field in pred:
for pred_value in pred_values: if field in answer:
if pred_key in answer and pred_value in answer[pred_key]: total_tp += 1
answer[pred_key].remove(pred_value) answer.remove(field)
total_tp += 1 else:
else: total_fn_or_fp += 1
total_fn_or_fp += 1 total_fn_or_fp += len(answer)
return total_tp / (total_tp + (total_fn_or_fp) / 2) return total_tp / (total_tp + total_fn_or_fp / 2)
def construct_tree_from_dict(self, data: Union[Dict, List], node_name: str = None): def construct_tree_from_dict(self, data: Union[Dict, List], node_name: str = None):
""" """
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment