Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
donut_pytorch
Commits
4f4810f0
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "dcbfd93d7aeb14f8ff08a48866d2a68950d4c69a"
Commit
4f4810f0
authored
Aug 19, 2022
by
Geewook Kim
Browse files
feat: add functions to calculate f1 accuracy score
parent
95cde5a9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
80 additions
and
18 deletions
+80
-18
donut/util.py
donut/util.py
+58
-9
test.py
test.py
+22
-9
No files found.
donut/util.py
View file @
4f4810f0
...
@@ -6,6 +6,7 @@ MIT License
...
@@ -6,6 +6,7 @@ MIT License
import
json
import
json
import
os
import
os
import
random
import
random
from
collections
import
defaultdict
from
typing
import
Any
,
Dict
,
List
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
List
,
Tuple
,
Union
import
torch
import
torch
...
@@ -31,7 +32,7 @@ class DonutDataset(Dataset):
...
@@ -31,7 +32,7 @@ class DonutDataset(Dataset):
"""
"""
DonutDataset which is saved in huggingface datasets format. (see details in https://huggingface.co/docs/datasets)
DonutDataset which is saved in huggingface datasets format. (see details in https://huggingface.co/docs/datasets)
Each row, consists of image path(png/jpg/jpeg) and gt data (json/jsonl/txt),
Each row, consists of image path(png/jpg/jpeg) and gt data (json/jsonl/txt),
and it will be converted into input_tensor(vectorized image) and input_ids(tokenized string)
.
and it will be converted into input_tensor(vectorized image) and input_ids(tokenized string)
Args:
Args:
dataset_name_or_path: name of dataset (available at huggingface.co/datasets) or the path containing image files and metadata.jsonl
dataset_name_or_path: name of dataset (available at huggingface.co/datasets) or the path containing image files and metadata.jsonl
...
@@ -94,7 +95,7 @@ class DonutDataset(Dataset):
...
@@ -94,7 +95,7 @@ class DonutDataset(Dataset):
def
__getitem__
(
self
,
idx
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
def
__getitem__
(
self
,
idx
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
Load image from image_path of given dataset_path and convert into input_tensor and labels
Load image from image_path of given dataset_path and convert into input_tensor and labels
.
Convert gt data into input_ids (tokenized string)
Convert gt data into input_ids (tokenized string)
Returns:
Returns:
...
@@ -136,18 +137,50 @@ class DonutDataset(Dataset):
...
@@ -136,18 +137,50 @@ class DonutDataset(Dataset):
class
JSONParseEvaluator
:
class
JSONParseEvaluator
:
"""
"""
Calculate n-TED(Normalized Tree Edit Distance) based accuracy between a predicted json and a gold json,
Calculate n-TED(Normalized Tree Edit Distance) based accuracy and F1 accuracy score
calculated as,
accuracy = 1 - TED(normalize(pred), normalize(gold)) / TED({}, normalize(gold))
"""
"""
@
staticmethod
def
flatten
(
data
:
dict
):
"""
Convert Dictionary into Non-nested Dictionary
Example:
input(dict)
{
"menu": [
{"name" : ["cake"], "count" : ["2"]},
{"name" : ["juice"], "count" : ["1"]},
]
}
output(dict)
{
"menu.name": ["cake", "juice"],
"menu.count": ["2", "1"],
}
"""
flatten_data
=
defaultdict
(
list
)
def
_flatten
(
value
,
key
=
""
):
if
type
(
value
)
is
dict
:
for
child_key
,
child_value
in
value
.
items
():
_flatten
(
child_value
,
f
"
{
key
}
.
{
child_key
}
"
if
key
else
child_key
)
elif
type
(
value
)
is
list
:
for
value_item
in
value
:
_flatten
(
value_item
,
key
)
else
:
flatten_data
[
key
].
append
(
value
)
_flatten
(
data
)
return
dict
(
flatten_data
)
@
staticmethod
@
staticmethod
def
update_cost
(
label1
:
str
,
label2
:
str
):
def
update_cost
(
label1
:
str
,
label2
:
str
):
"""
"""
Update cost for tree edit distance.
Update cost for tree edit distance.
If both are leaf node, calculate string edit distance between two labels (special token '<leaf>' will be ignored).
If both are leaf node, calculate string edit distance between two labels (special token '<leaf>' will be ignored).
If one of them is leaf node, cost is length of string in leaf node + 1.
If one of them is leaf node, cost is length of string in leaf node + 1.
If neither is a leaf node, cost is 0 if label1 is the same as label2, otherwise 1
.
If neither is a leaf node, cost is 0 if label1 is the same as label2, otherwise 1
"""
"""
label1_leaf
=
"<leaf>"
in
label1
label1_leaf
=
"<leaf>"
in
label1
label2_leaf
=
"<leaf>"
in
label2
label2_leaf
=
"<leaf>"
in
label2
...
@@ -161,7 +194,7 @@ class JSONParseEvaluator:
...
@@ -161,7 +194,7 @@ class JSONParseEvaluator:
return
int
(
label1
!=
label2
)
return
int
(
label1
!=
label2
)
@
staticmethod
@
staticmethod
def
insert_and_remove_cost
(
node
):
def
insert_and_remove_cost
(
node
:
Node
):
"""
"""
Insert and remove cost for tree edit distance.
Insert and remove cost for tree edit distance.
If leaf node, cost is length of label name.
If leaf node, cost is length of label name.
...
@@ -175,7 +208,7 @@ class JSONParseEvaluator:
...
@@ -175,7 +208,7 @@ class JSONParseEvaluator:
def
normalize_dict
(
self
,
data
:
Union
[
Dict
,
List
,
Any
]):
def
normalize_dict
(
self
,
data
:
Union
[
Dict
,
List
,
Any
]):
"""
"""
Sort by value, while iterate over element if data is list
.
Sort by value, while iterate over element if data is list
"""
"""
if
not
data
:
if
not
data
:
return
{}
return
{}
...
@@ -203,6 +236,22 @@ class JSONParseEvaluator:
...
@@ -203,6 +236,22 @@ class JSONParseEvaluator:
return
new_data
return
new_data
def
cal_f1
(
self
,
preds
:
List
[
dict
],
answers
:
List
[
dict
]):
"""
Calculate global F1 accuracy score (field-level, micro-averaged) by counting all true positives, false negatives and false positives
"""
total_tp
,
total_fn_or_fp
=
0
,
0
for
pred
,
answer
in
zip
(
preds
,
answers
):
pred
,
answer
=
self
.
flatten
(
self
.
normalize_dict
(
pred
)),
self
.
flatten
(
self
.
normalize_dict
(
answer
))
for
pred_key
,
pred_values
in
pred
.
items
():
for
pred_value
in
pred_values
:
if
pred_key
in
answer
and
pred_value
in
answer
[
pred_key
]:
answer
[
pred_key
].
remove
(
pred_value
)
total_tp
+=
1
else
:
total_fn_or_fp
+=
1
return
total_tp
/
(
total_tp
+
(
total_fn_or_fp
)
/
2
)
def
construct_tree_from_dict
(
self
,
data
:
Union
[
Dict
,
List
],
node_name
:
str
=
None
):
def
construct_tree_from_dict
(
self
,
data
:
Union
[
Dict
,
List
],
node_name
:
str
=
None
):
"""
"""
Convert Dictionary into Tree
Convert Dictionary into Tree
...
@@ -252,7 +301,7 @@ class JSONParseEvaluator:
...
@@ -252,7 +301,7 @@ class JSONParseEvaluator:
raise
Exception
(
data
,
node_name
)
raise
Exception
(
data
,
node_name
)
return
node
return
node
def
cal_acc
(
self
,
pred
,
answer
):
def
cal_acc
(
self
,
pred
:
dict
,
answer
:
dict
):
"""
"""
Calculate normalized tree edit distance(nTED) based accuracy.
Calculate normalized tree edit distance(nTED) based accuracy.
1) Construct tree from dict,
1) Construct tree from dict,
...
...
test.py
View file @
4f4810f0
...
@@ -32,9 +32,11 @@ def test(args):
...
@@ -32,9 +32,11 @@ def test(args):
if
args
.
save_path
:
if
args
.
save_path
:
os
.
makedirs
(
os
.
path
.
dirname
(
args
.
save_path
),
exist_ok
=
True
)
os
.
makedirs
(
os
.
path
.
dirname
(
args
.
save_path
),
exist_ok
=
True
)
output_list
=
[]
predictions
=
[]
ground_truths
=
[]
accs
=
[]
accs
=
[]
evaluator
=
JSONParseEvaluator
()
dataset
=
load_dataset
(
args
.
dataset_name_or_path
,
split
=
args
.
split
)
dataset
=
load_dataset
(
args
.
dataset_name_or_path
,
split
=
args
.
split
)
for
idx
,
sample
in
tqdm
(
enumerate
(
dataset
),
total
=
len
(
dataset
)):
for
idx
,
sample
in
tqdm
(
enumerate
(
dataset
),
total
=
len
(
dataset
)):
...
@@ -52,24 +54,35 @@ def test(args):
...
@@ -52,24 +54,35 @@ def test(args):
gt
=
ground_truth
[
"gt_parse"
]
gt
=
ground_truth
[
"gt_parse"
]
score
=
float
(
output
[
"class"
]
==
gt
[
"class"
])
score
=
float
(
output
[
"class"
]
==
gt
[
"class"
])
elif
args
.
task_name
==
"docvqa"
:
elif
args
.
task_name
==
"docvqa"
:
score
=
0.0
# note: docvqa is evaluated on the official website
# Note: we evaluated the model on the official website.
# In this script, an exact-match based score will be returned instead
gt
=
ground_truth
[
"gt_parses"
]
answers
=
set
([
qa_parse
[
"answer"
]
for
qa_parse
in
gt
])
score
=
float
(
output
[
"answer"
]
in
answers
)
else
:
else
:
gt
=
ground_truth
[
"gt_parse"
]
gt
=
ground_truth
[
"gt_parse"
]
evaluator
=
JSONParseEvaluator
()
score
=
evaluator
.
cal_acc
(
output
,
gt
)
score
=
evaluator
.
cal_acc
(
output
,
gt
)
accs
.
append
(
score
)
accs
.
append
(
score
)
output_list
.
append
(
output
)
predictions
.
append
(
output
)
ground_truths
.
append
(
gt
)
scores
=
{
"accuracies"
:
accs
,
"mean_accuracy"
:
np
.
mean
(
accs
)}
scores
=
{
print
(
scores
,
f
"length :
{
len
(
accs
)
}
"
)
"ted_accuracies"
:
accs
,
"ted_accuracy"
:
np
.
mean
(
accs
),
"f1_accuracy"
:
evaluator
.
cal_f1
(
predictions
,
ground_truths
),
}
print
(
f
"Total number of samples:
{
len
(
accs
)
}
, Tree Edit Distance (TED) based accuracy score:
{
scores
[
'ted_accuracy'
]
}
, F1 accuracy score:
{
scores
[
'f1_accuracy'
]
}
"
)
if
args
.
save_path
:
if
args
.
save_path
:
scores
[
"predictions"
]
=
output_list
scores
[
"predictions"
]
=
predictions
scores
[
"ground_truths"
]
=
ground_truths
save_json
(
args
.
save_path
,
scores
)
save_json
(
args
.
save_path
,
scores
)
return
output_list
return
predictions
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
@@ -84,4 +97,4 @@ if __name__ == "__main__":
...
@@ -84,4 +97,4 @@ if __name__ == "__main__":
if
args
.
task_name
is
None
:
if
args
.
task_name
is
None
:
args
.
task_name
=
os
.
path
.
basename
(
args
.
dataset_name_or_path
)
args
.
task_name
=
os
.
path
.
basename
(
args
.
dataset_name_or_path
)
predicts
=
test
(
args
)
predict
ion
s
=
test
(
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment