ModelZoo / donut_pytorch · Commit 4f4810f0

Authored Aug 19, 2022 by Geewook Kim
Parent: 95cde5a9

feat: add functions to calculate f1 accuracy score
Showing 2 changed files with 80 additions and 18 deletions:
donut/util.py (+58 −9)
test.py (+22 −9)
donut/util.py
@@ -6,6 +6,7 @@ MIT License
 import json
 import os
 import random
+from collections import defaultdict
 from typing import Any, Dict, List, Tuple, Union

 import torch
@@ -31,7 +32,7 @@ class DonutDataset(Dataset):
     """
     DonutDataset which is saved in huggingface datasets format. (see details in https://huggingface.co/docs/datasets)
     Each row consists of image path(png/jpg/jpeg) and gt data (json/jsonl/txt),
-    and it will be converted into input_tensor(vectorized image) and input_ids(tokenized string).
+    and it will be converted into input_tensor(vectorized image) and input_ids(tokenized string)

     Args:
         dataset_name_or_path: name of dataset (available at huggingface.co/datasets) or the path containing image files and metadata.jsonl
@@ -94,7 +95,7 @@ class DonutDataset(Dataset):

     def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
-        Load image from image_path of given dataset_path and convert into input_tensor and labels
+        Load image from image_path of given dataset_path and convert into input_tensor and labels.
         Convert gt data into input_ids (tokenized string)

         Returns:
@@ -136,18 +137,50 @@ class DonutDataset(Dataset):
 class JSONParseEvaluator:
     """
-    Calculate n-TED(Normalized Tree Edit Distance) based accuracy between a predicted json and a gold json,
-    calculated as,
-        accuracy = 1 - TED(normalize(pred), normalize(gold)) / TED({}, normalize(gold))
+    Calculate n-TED(Normalized Tree Edit Distance) based accuracy and F1 accuracy score
     """

+    @staticmethod
+    def flatten(data: dict):
+        """
+        Convert Dictionary into Non-nested Dictionary
+        Example:
+            input(dict)
+                {
+                    "menu": [
+                        {"name" : ["cake"], "count" : ["2"]},
+                        {"name" : ["juice"], "count" : ["1"]},
+                    ]
+                }
+            output(dict)
+                {
+                    "menu.name": ["cake", "juice"],
+                    "menu.count": ["2", "1"],
+                }
+        """
+        flatten_data = defaultdict(list)
+
+        def _flatten(value, key=""):
+            if type(value) is dict:
+                for child_key, child_value in value.items():
+                    _flatten(child_value, f"{key}.{child_key}" if key else child_key)
+            elif type(value) is list:
+                for value_item in value:
+                    _flatten(value_item, key)
+            else:
+                flatten_data[key].append(value)
+
+        _flatten(data)
+        return dict(flatten_data)
+
     @staticmethod
     def update_cost(label1: str, label2: str):
         """
         Update cost for tree edit distance.
         If both are leaf node, calculate string edit distance between two labels (special token '<leaf>' will be ignored).
         If one of them is leaf node, cost is length of string in leaf node + 1.
-        If neither are leaf node, cost is 0 if label1 is same with label2 otherwise 1.
+        If neither are leaf node, cost is 0 if label1 is same with label2 otherwise 1
         """
         label1_leaf = "<leaf>" in label1
         label2_leaf = "<leaf>" in label2
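As a quick sanity check, the new flatten helper reproduces the example in its own docstring. Assuming the package layout makes it importable as donut.util (the import path is not shown in this diff):

from donut.util import JSONParseEvaluator

nested = {
    "menu": [
        {"name": ["cake"], "count": ["2"]},
        {"name": ["juice"], "count": ["1"]},
    ]
}
# Nested dict keys are joined with "."; list items are merged under one key.
print(JSONParseEvaluator.flatten(nested))
# {'menu.name': ['cake', 'juice'], 'menu.count': ['2', '1']}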
@@ -161,7 +194,7 @@ class JSONParseEvaluator:
             return int(label1 != label2)

     @staticmethod
-    def insert_and_remove_cost(node):
+    def insert_and_remove_cost(node: Node):
         """
         Insert and remove cost for tree edit distance.
         If leaf node, cost is length of label name.
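Only fragments of the two cost functions are visible in these hunks. A self-contained sketch of the scheme exactly as the docstrings describe it, using plain label strings instead of Node objects for brevity, with a textbook Levenshtein routine standing in for whatever edit-distance helper util.py actually imports (that import is outside this diff):

def levenshtein(a: str, b: str) -> int:
    # Plain DP string edit distance; a stand-in, not the repository's helper.
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = cur
    return prev[-1]

def update_cost_sketch(label1: str, label2: str) -> int:
    label1_leaf = "<leaf>" in label1
    label2_leaf = "<leaf>" in label2
    if label1_leaf and label2_leaf:
        # both leaves: string edit distance with the '<leaf>' token ignored
        return levenshtein(label1.replace("<leaf>", ""), label2.replace("<leaf>", ""))
    if label1_leaf or label2_leaf:
        # exactly one leaf: length of the leaf's string + 1
        leaf = label1 if label1_leaf else label2
        return len(leaf.replace("<leaf>", "")) + 1
    # neither is a leaf: 0 if the labels match, otherwise 1
    return int(label1 != label2)

def insert_and_remove_cost_sketch(label: str) -> int:
    # Per the docstring, a leaf node costs the length of its label name;
    # the non-leaf cost is cut off by this hunk, so 1 is an assumption.
    if "<leaf>" in label:
        return len(label.replace("<leaf>", ""))
    return 1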
@@ -175,7 +208,7 @@ class JSONParseEvaluator:

     def normalize_dict(self, data: Union[Dict, List, Any]):
         """
-        Sort by value, while iterating over elements if data is a list.
+        Sort by value, while iterating over elements if data is a list
         """
         if not data:
             return {}
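The body of normalize_dict sits outside this hunk. As a rough illustration of the idea only, not the repository's implementation, the canonicalization its docstring describes amounts to something like:

def normalize_sketch(data):
    # Recursively sort so that two semantically equal parses compare equal
    # regardless of element order (illustration only; the real code differs).
    if isinstance(data, dict):
        return {key: normalize_sketch(value) for key, value in data.items()}
    if isinstance(data, list):
        return sorted((normalize_sketch(item) for item in data), key=str)
    return data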
@@ -203,6 +236,22 @@ class JSONParseEvaluator:

         return new_data

+    def cal_f1(self, preds: List[dict], answers: List[dict]):
+        """
+        Calculate global F1 accuracy score (field-level, micro-averaged) by counting all true positives, false negatives and false positives
+        """
+        total_tp, total_fn_or_fp = 0, 0
+        for pred, answer in zip(preds, answers):
+            pred, answer = self.flatten(self.normalize_dict(pred)), self.flatten(self.normalize_dict(answer))
+            for pred_key, pred_values in pred.items():
+                for pred_value in pred_values:
+                    if pred_key in answer and pred_value in answer[pred_key]:
+                        answer[pred_key].remove(pred_value)
+                        total_tp += 1
+                    else:
+                        total_fn_or_fp += 1
+        return total_tp / (total_tp + total_fn_or_fp / 2)
+
     def construct_tree_from_dict(self, data: Union[Dict, List], node_name: str = None):
         """
         Convert Dictionary into Tree
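The returned expression is the usual micro-F1 identity: with TP matched field values and M unmatched ones (false positives and false negatives combined), F1 = 2·TP / (2·TP + FP + FN) = TP / (TP + M/2). A small worked example, again assuming the donut.util import path:

from donut.util import JSONParseEvaluator

evaluator = JSONParseEvaluator()
pred = [{"name": ["cake"], "count": ["2"]}]  # predicts one extra field
gold = [{"name": ["cake"]}]
# After normalize+flatten: "cake" matches (TP=1), count "2" is unmatched (M=1),
# so the expected score is 1 / (1 + 1/2) = 2/3.
print(evaluator.cal_f1(pred, gold))  # expected ~0.667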
@@ -252,7 +301,7 @@ class JSONParseEvaluator:
                 raise Exception(data, node_name)

         return node

-    def cal_acc(self, pred, answer):
+    def cal_acc(self, pred: dict, answer: dict):
         """
         Calculate normalized tree edit distance(nTED) based accuracy.
         1) Construct tree from dict,
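The formula dropped from the class docstring above still describes what cal_acc computes: accuracy = 1 - TED(normalize(pred), normalize(gold)) / TED({}, normalize(gold)). A sketch of that pipeline, where TED() is a hypothetical stand-in for the tree-edit-distance call in util.py (not visible in this diff) and the clamp at 0 is an assumption:

def cal_acc_sketch(evaluator, pred: dict, answer: dict) -> float:
    # 1) normalize both parses, 2) build trees, 3) normalize the edit
    # distance by the distance from an empty tree to the gold tree.
    pred_tree = evaluator.construct_tree_from_dict(evaluator.normalize_dict(pred))
    answer_tree = evaluator.construct_tree_from_dict(evaluator.normalize_dict(answer))
    empty_tree = evaluator.construct_tree_from_dict(evaluator.normalize_dict({}))
    distance = TED(pred_tree, answer_tree)    # hypothetical TED() helper
    normalizer = TED(empty_tree, answer_tree)
    return max(0.0, 1.0 - distance / normalizer)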
test.py
@@ -32,9 +32,11 @@ def test(args):
     if args.save_path:
         os.makedirs(os.path.dirname(args.save_path), exist_ok=True)

-    output_list = []
+    predictions = []
+    ground_truths = []
     accs = []

+    evaluator = JSONParseEvaluator()
     dataset = load_dataset(args.dataset_name_or_path, split=args.split)

     for idx, sample in tqdm(enumerate(dataset), total=len(dataset)):
@@ -52,24 +54,35 @@ def test(args):
             gt = ground_truth["gt_parse"]
             score = float(output["class"] == gt["class"])
         elif args.task_name == "docvqa":
-            score = 0.0
-            # note: docvqa is evaluated on the official website
+            # Note: we evaluated the model on the official website.
+            # In this script, an exact-match based score will be returned instead
+            gt = ground_truth["gt_parses"]
+            answers = set([qa_parse["answer"] for qa_parse in gt])
+            score = float(output["answer"] in answers)
         else:
             gt = ground_truth["gt_parse"]
-            evaluator = JSONParseEvaluator()
             score = evaluator.cal_acc(output, gt)

         accs.append(score)
-        output_list.append(output)
+        predictions.append(output)
+        ground_truths.append(gt)

-    scores = {"accuracies": accs, "mean_accuracy": np.mean(accs)}
-    print(scores, f"length : {len(accs)}")
+    scores = {
+        "ted_accuracies": accs,
+        "ted_accuracy": np.mean(accs),
+        "f1_accuracy": evaluator.cal_f1(predictions, ground_truths),
+    }
+    print(
+        f"Total number of samples: {len(accs)}, Tree Edit Distance (TED) based accuracy score: {scores['ted_accuracy']}, F1 accuracy score: {scores['f1_accuracy']}"
+    )

     if args.save_path:
-        scores["predictions"] = output_list
+        scores["predictions"] = predictions
+        scores["ground_truths"] = ground_truths
         save_json(args.save_path, scores)

-    return output_list
+    return predictions


 if __name__ == "__main__":
@@ -84,4 +97,4 @@ if __name__ == "__main__":
     if args.task_name is None:
         args.task_name = os.path.basename(args.dataset_name_or_path)

-    predicts = test(args)
+    predictions = test(args)
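With these changes the script reports both metrics in one pass. A possible invocation using only the flags that appear in this diff (any model or checkpoint arguments the script also needs are not visible here, and the dataset path is just an example; --task_name defaults to the basename of the dataset path):

python test.py \
    --dataset_name_or_path naver-clova-ix/cord-v2 \
    --split test \
    --save_path ./result/cord_v2.json

# Expected console output shape:
# Total number of samples: ..., Tree Edit Distance (TED) based accuracy score: ..., F1 accuracy score: ...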