Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
donut_pytorch
Commits
86bcafe1
"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "912fdff899cf0fd674ed357e46a0209311aefad2"
Commit
86bcafe1
authored
Aug 23, 2022
by
Geewook Kim
Browse files
feat: update JSONParseEvaluator
parent
d2fd95a3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
19 deletions
+21
-19
donut/util.py
donut/util.py
+21
-19
No files found.
donut/util.py
View file @
86bcafe1
...
@@ -137,14 +137,13 @@ class DonutDataset(Dataset):
...
@@ -137,14 +137,13 @@ class DonutDataset(Dataset):
class
JSONParseEvaluator
:
class
JSONParseEvaluator
:
"""
"""
Calculate n-TED(Normalized Tree Edit Distance) based accuracy and
F1 accuracy score
Calculate n-TED(Normalized Tree Edit Distance) based accuracy and F1 accuracy score
"""
"""
@
staticmethod
@
staticmethod
def
flatten
(
data
:
dict
):
def
flatten
(
data
:
dict
):
"""
"""
Convert Dictionary into Non-nested Dictionary
Convert Dictionary into Non-nested Dictionary
Example:
Example:
input(dict)
input(dict)
{
{
...
@@ -153,13 +152,15 @@ class JSONParseEvaluator:
...
@@ -153,13 +152,15 @@ class JSONParseEvaluator:
{"name" : ["juice"], "count" : ["1"]},
{"name" : ["juice"], "count" : ["1"]},
]
]
}
}
output(dict)
output(list)
{
[
"menu.name": ["cake", "juice"],
("menu.name", "cake"),
"menu.count": ["2", "1"],
("menu.count", "2"),
}
("menu.name", "juice"),
("menu.count", "1"),
]
"""
"""
flatten_data
=
defaultdict
(
list
)
flatten_data
=
list
(
)
def
_flatten
(
value
,
key
=
""
):
def
_flatten
(
value
,
key
=
""
):
if
type
(
value
)
is
dict
:
if
type
(
value
)
is
dict
:
...
@@ -169,10 +170,10 @@ class JSONParseEvaluator:
...
@@ -169,10 +170,10 @@ class JSONParseEvaluator:
for
value_item
in
value
:
for
value_item
in
value
:
_flatten
(
value_item
,
key
)
_flatten
(
value_item
,
key
)
else
:
else
:
flatten_data
[
key
]
.
append
(
value
)
flatten_data
.
append
(
(
key
,
value
)
)
_flatten
(
data
)
_flatten
(
data
)
return
dict
(
flatten_data
)
return
flatten_data
@
staticmethod
@
staticmethod
def
update_cost
(
label1
:
str
,
label2
:
str
):
def
update_cost
(
label1
:
str
,
label2
:
str
):
...
@@ -225,10 +226,11 @@ class JSONParseEvaluator:
...
@@ -225,10 +226,11 @@ class JSONParseEvaluator:
elif
isinstance
(
data
,
list
):
elif
isinstance
(
data
,
list
):
if
all
(
isinstance
(
item
,
dict
)
for
item
in
data
):
if
all
(
isinstance
(
item
,
dict
)
for
item
in
data
):
new_data
=
[]
new_data
=
[]
for
item
in
sorted
(
data
,
key
=
lambda
x
:
str
(
sorted
(
x
.
items
())))
:
for
item
in
data
:
item
=
self
.
normalize_dict
(
item
)
item
=
self
.
normalize_dict
(
item
)
if
item
:
if
item
:
new_data
.
append
(
item
)
new_data
.
append
(
item
)
new_data
=
sorted
(
new_data
,
key
=
lambda
x
:
str
(
x
.
keys
())
+
str
(
x
.
values
()))
else
:
else
:
new_data
=
sorted
([
str
(
item
)
for
item
in
data
if
type
(
item
)
in
{
str
,
int
,
float
}
and
str
(
item
)])
new_data
=
sorted
([
str
(
item
)
for
item
in
data
if
type
(
item
)
in
{
str
,
int
,
float
}
and
str
(
item
)])
else
:
else
:
...
@@ -243,14 +245,14 @@ class JSONParseEvaluator:
...
@@ -243,14 +245,14 @@ class JSONParseEvaluator:
total_tp
,
total_fn_or_fp
=
0
,
0
total_tp
,
total_fn_or_fp
=
0
,
0
for
pred
,
answer
in
zip
(
preds
,
answers
):
for
pred
,
answer
in
zip
(
preds
,
answers
):
pred
,
answer
=
self
.
flatten
(
self
.
normalize_dict
(
pred
)),
self
.
flatten
(
self
.
normalize_dict
(
answer
))
pred
,
answer
=
self
.
flatten
(
self
.
normalize_dict
(
pred
)),
self
.
flatten
(
self
.
normalize_dict
(
answer
))
for
pred_key
,
pred_values
in
pred
.
items
()
:
for
field
in
pred
:
for
pred_value
in
pred_values
:
if
field
in
answer
:
if
pred_key
in
answer
and
pred_value
in
answer
[
pred_key
]:
total_tp
+=
1
answer
[
pred_key
].
remove
(
pred_value
)
answer
.
remove
(
field
)
total_tp
+=
1
else
:
else
:
total_fn_or_fp
+=
1
total_fn_or_fp
+=
1
total_fn_or_fp
+=
len
(
answer
)
return
total_tp
/
(
total_tp
+
(
total_fn_or_fp
)
/
2
)
return
total_tp
/
(
total_tp
+
total_fn_or_fp
/
2
)
def
construct_tree_from_dict
(
self
,
data
:
Union
[
Dict
,
List
],
node_name
:
str
=
None
):
def
construct_tree_from_dict
(
self
,
data
:
Union
[
Dict
,
List
],
node_name
:
str
=
None
):
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment