gaoqiong / lm-evaluation-harness

Commit c6a35696, authored Feb 11, 2022 by Stephen Hogg

Include extraction in process_results; fixes per test results

parent c2f12474
Showing 1 changed file with 106 additions and 39 deletions.
lm_eval/tasks/qasper.py (+106, -39)
@@ -22,12 +22,86 @@ https://arxiv.org/abs/2105.03011
 bibsource = {dblp computer science bibliography, https://dblp.org}
 }
 """
+from collections import Counter
 from math import exp
+import re
+import string
 from lm_eval.base import rf
-from lm_eval.metrics import f1_score
+from lm_eval.metrics import f1_score, mean
 from .common import HFTask
+
+
+def normalize_answer(s):
+    """
+    Taken from the official evaluation script for v1.1 of the SQuAD dataset.
+    Lower text and remove punctuation, articles and extra whitespace.
+    """
+
+    def remove_articles(text):
+        return re.sub(r"\b(a|an|the)\b", " ", text)
+
+    def white_space_fix(text):
+        return " ".join(text.split())
+
+    def remove_punc(text):
+        exclude = set(string.punctuation)
+        return "".join(ch for ch in text if ch not in exclude)
+
+    def lower(text):
+        return text.lower()
+
+    return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def categorise_answer(answer_blob):
+    if answer_blob["unanswerable"]:
+        answer = "unanswerable"
+        answer_type = "unanswerable"
+        return answer, answer_type
+    elif answer_blob["yes_no"]:
+        answer = "Yes"
+        answer_type = "bool"
+        return answer, answer_type
+    elif answer_blob["free_form_answer"]:
+        answer = answer_blob["free_form_answer"]
+        answer_type = "free form answer"
+        return answer, answer_type
+    elif answer_blob["extractive_spans"]:
+        answer = answer_blob["extractive_spans"]
+        answer_type = "extractive spans"
+        return answer, answer_type
+    elif answer_blob["yes_no"] is False:
+        answer = "No"
+        answer_type = "bool"
+        return answer, answer_type
+
+
+def token_f1_score(prediction, ground_truth):
+    """
+    Taken from the official evaluation script for v1.1 of the SQuAD dataset.
+    """
+    prediction_tokens = normalize_answer(prediction).split()
+    ground_truth_tokens = normalize_answer(ground_truth).split()
+    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+    num_same = sum(common.values())
+    if num_same == 0:
+        return 0
+    precision = 1.0 * num_same / len(prediction_tokens)
+    recall = 1.0 * num_same / len(ground_truth_tokens)
+    f1 = (2 * precision * recall) / (precision + recall)
+    return f1
+
+
+def paragraph_f1_score(prediction, ground_truth):
+    num_same = len(set(ground_truth).intersection(set(prediction)))
+    if num_same == 0:
+        return 0.0
+    precision = num_same / len(prediction)
+    recall = num_same / len(ground_truth)
+    f1 = (2 * precision * recall) / (precision + recall)
+    return f1
+
+
 class QASPER(HFTask):
     VERSION = 0
     DATASET_PATH = "qasper"
@@ -50,7 +124,7 @@ class QASPER(HFTask):
     def doc_to_target(self, doc):
         # this method is invoked by tests only
-        return " " + doc["answer_str"]
+        return " " + doc["answer"]
 
     def training_docs(self):
         for doc in self.data["train"]:
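With the key renamed from "answer_str" to "answer", the target is built from the same field that the rewritten document-building loop below populates; for example (illustrative only):

    doc = {"answer": "unanswerable"}
    # doc_to_target prepends a space so the target follows the prompt text
    " " + doc["answer"]  # -> " unanswerable"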
@@ -67,33 +141,18 @@ class QASPER(HFTask):
         https://github.com/allenai/qasper-led-baseline/blob/main/scripts/evaluator.py
         """
         obs_list = []
-        for qa in doc["qas"]:
-            for question, answer_list in zip(qa["question"], qa["answers"]):
-                for answer in answer_list:
-                    if answer["unanswerable"]:
-                        answer_str = "unanswerable"
-                        answer_type = "unanswerable"
-                    elif answer["yes_no"]:
-                        answer_str = "Yes"
-                        answer_type = "bool"
-                    elif answer["yes_no"] is not None:
-                        answer_str = "No"
-                        answer_type = "bool"
-                    elif answer["free_form_answer"]:
-                        answer_str = answer["free_form_answer"]
-                        answer_type = "free form answer"
-                    elif answer["extractive_spans"]:
-                        answer_str = ", ".join(answer["extractive_spans"])
-                        answer_type = "extractive spans"
-                    obs_list.append[
-                        {
-                            "title": doc["title"],
-                            "abstract": doc["abstract"],
-                            "question": question,
-                            "answer_str": answer_str,
-                            "answer_type": answer_type,
-                        }
-                    ]
+        for question, answer_list in zip(doc["qas"]["question"], doc["qas"]["answers"]):
+            for answer_blob in answer_list["answer"]:
+                answer, answer_type = categorise_answer(answer_blob)
+                obs_list.append(
+                    {
+                        "title": doc["title"],
+                        "abstract": doc["abstract"],
+                        "question": question,
+                        "answer": answer,
+                        "answer_type": answer_type,
+                    }
+                )
         return obs_list
 
     def process_results(self, doc, results):
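A quick sketch of how categorise_answer feeds the new loop (illustrative, not part of the commit): the answer fields are checked in priority order, and extractive answers now come back as a list of spans rather than a joined string:

    blob = {
        "unanswerable": False,
        "yes_no": None,
        "free_form_answer": "",
        "extractive_spans": ["span one", "span two"],
    }
    # earlier branches are falsy, so the extractive branch matches
    categorise_answer(blob)  # -> (["span one", "span two"], "extractive spans")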
@@ -114,16 +173,15 @@ class QASPER(HFTask):
         # Handle completions
         if doc["answer_type"] == "free form answer":
-            res_dict["f1_ab"] = None
+            res_dict["f1_ab"] = token_f1_score(res["answer"], doc["answer"])
+        # Handle extraction
+        if doc["answer_type"] == "extractive spans":
+            res_dict["f1_ex"] = paragraph_f1_score(res["answer"], doc["answer"])
         return res_dict
 
     def aggregation(self):
-        return {
-            "f1_un": f1_score,
-            "f1_yn": f1_score,
-            "f1_ab": f1_score,
-            "f1_ex": f1_score,
-        }
+        return {"f1_un": f1_score, "f1_yn": f1_score, "f1_ab": mean, "f1_ex": mean}
 
     def construct_requests(self, doc, ctx):
         """Uses RequestFactory to construct Requests and returns an iterable of
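The aggregation change above swaps the corpus-level f1_score aggregator for a plain mean on the generative submetrics, so per-instance F1 values are averaged directly. Roughly (assuming mean from lm_eval.metrics is an arithmetic mean):

    per_instance = [0.8, 0.5, 1.0]         # token_f1_score per free-form doc
    sum(per_instance) / len(per_instance)  # -> 0.766..., the reported f1_ab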
@@ -138,9 +196,18 @@ class QASPER(HFTask):
"""
unanswerable
=
rf
.
loglikelihood
(
ctx
,
" "
+
"unanswerable"
)
if
doc
[
"answer_type"
]
in
(
"free form answer"
,
"extractive spans"
):
re
s
=
rf
.
greedy_until
(
ctx
,
[
"
\n
"
])
re
turn
rf
.
greedy_until
(
ctx
,
[
"
\n
"
])
,
unanswerable
elif
doc
[
"answer_type"
]
in
(
"bool"
):
ll_yes
,
_
=
rf
.
loglikelihood
(
ctx
,
" yes"
)
ll_no
,
_
=
rf
.
loglikelihood
(
ctx
,
" no"
)
res
=
(
ll_yes
,
ll_no
)
return
res
,
unanswerable
return
ll_yes
,
ll_no
,
unanswerable
else
:
return
unanswerable
def
higher_is_better
(
self
):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return
{
"f1_un"
:
True
,
"f1_yn"
:
True
,
"f1_ab"
:
True
,
"f1_ex"
:
True
}
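Finally, an illustrative check of the extraction path this commit wires into process_results: paragraph_f1_score measures overlap between collections of spans rather than tokens:

    pred = ["span a", "span b"]
    gold = ["span b", "span c", "span d"]
    # num_same = 1, precision = 1/2, recall = 1/3
    paragraph_f1_score(pred, gold)  # -> 0.4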