gaoqiong / lm-evaluation-harness

Commit efa810f0 (parent 5552c8dc), authored Feb 03, 2021 by thefazzer

Score computation, use squad metrics
Showing 1 changed file (lm_eval/tasks/coqa.py) with 38 additions and 21 deletions.
lm_eval/tasks/coqa.py

@@ -6,6 +6,7 @@ import numpy as np
 from lm_eval.base import Dataset, rf, mean
 from ..utils import sh
 from itertools import zip_longest
+import transformers.data.metrics.squad_metrics as squad_metrics
 
 
 class CoQA(Dataset):
     def download(self):
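The added import brings in the SQuAD scoring helpers from transformers that the new code below relies on. As a quick standalone sketch (not part of the commit) of what the two functions return:

    from transformers.data.metrics.squad_metrics import compute_exact, compute_f1

    # Both take (gold, prediction) strings; normalization lowercases and
    # strips punctuation and the articles "a", "an", "the" before comparing.
    print(compute_exact("the cat", "The cat."))  # 1 -- exact match after normalization
    print(compute_f1("black cat", "black dog"))  # 0.5 -- one of two tokens overlaps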
@@ -39,16 +40,18 @@ class CoQA(Dataset):
         return "Given a passage and a conversation so far, answer the next question in the conversation."
 
     def doc_to_text(self, doc):
         # Each "doc" is a story and conversation (Q and A pairs).
         doc_text = doc["story"] + '\n\n'
         for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]):  # omit target answer
             question = f"Q: {q['input_text']}" + '\n\n'
             answer = f"A: {a['input_text']}" + '\n\n' if a is not None else "A:\n\n"
             doc_text += question + answer
         print(doc_text)
         return doc_text
 
     @classmethod
     def get_answers(cls, doc, turn_id):
-        # get answers and valid alternatives
+        # This function returns an answer and valid alternatives.
         answers = []
         answer_forturn = doc["answers"][turn_id - 1]["input_text"]
         answers.append(answer_forturn)
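For reference, here is a sketch of the prompt doc_to_text builds; the doc below is hypothetical, with field names following the CoQA schema used above. Because the last answer is sliced off, zip_longest pairs the final question with None and the prompt ends with an open "A:" for the model to complete:

    doc = {
        "story": "Once upon a time there lived a dragon.",
        "questions": [{"input_text": "Who lived there?"}],
        "answers": [{"input_text": "A dragon."}],
    }
    # doc_to_text(doc) returns:
    #
    # Once upon a time there lived a dragon.
    #
    # Q: Who lived there?
    #
    # A: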
@@ -62,12 +65,27 @@ class CoQA(Dataset):
         return answers
 
     def doc_to_target(self, doc, turnid=None):
-        # default to predict last turn
+        # Default to predict last turn.
         if turnid is None:
             turnid = len(doc["questions"])
         all_answers = self.get_answers(doc, turnid)
         return all_answers[0]  # ignore alternative answers for now
 
+    @staticmethod
+    def compute_scores(gold_list, pred):
+        f1_sum = 0.0
+        em_sum = 0.0
+        if len(gold_list) > 1:
+            for i in range(len(gold_list)):
+                gold_answers = gold_list[0:i] + gold_list[i + 1:]
+                em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
+                f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
+        else:
+            em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
+            f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)
+
+        return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))}
+
     def construct_requests(self, doc, ctx):
         """ Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.
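A standalone usage sketch for the new compute_scores, assuming the class above is importable. Each gold answer is scored by a leave-one-out max over the remaining alternatives, and the sums are averaged over the list:

    gold_list = ["white house", "the white house", "White House"]
    scores = CoQA.compute_scores(gold_list, "white house")
    # Every leave-one-out subset still contains an exact match after
    # normalization, so both averages come out at 1.0:
    print(scores)  # {'em': 1.0, 'f1': 1.0}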
@@ -80,8 +98,9 @@ class CoQA(Dataset):
         part of the document for `doc`.
         """
         requests = []
-        for answer in self.get_answers(doc, len(doc["questions"])):
-            requests.append(rf.loglikelihood(ctx, " " + answer))
+        for answers in self.get_answers(doc, len(doc["questions"])):
+            for a in answers:
+                requests.append(rf.loglikelihood(ctx, " " + a))
         return requests
 
     def process_results(self, doc, results):
@@ -94,28 +113,26 @@ class CoQA(Dataset):
         :param results:
             The results of the requests created in construct_requests.
         """
-        gold = self.get_answers(doc, len(doc["questions"]))
+        turn_id = len(doc["questions"])
+        gold_list = self.get_answers(doc, turn_id)
         pred = np.argmax(results)
+        (em, f1) = self.compute_scores(gold_list, pred)
 
         return {
-            "acc": int(pred == gold)
+            "f1": f1,
+            "em": em,
         }
 
-    def aggregation(self):
-        """
-        :returns: {str: [float] -> float}
-            A dictionary where keys are the names of submetrics and values are
-            functions that aggregate a list of metrics
-        """
+    def higher_is_better(self):
         return {
-            "acc": mean
+            "f1": True,
+            "em": True,
         }
 
-    def higher_is_better(self):
-        """
-        :returns: {str: bool}
-            A dictionary where keys are the names of submetrics and values are
-            whether a higher value of the submetric is better
-        """
+    def aggregation(self):
         return {
-            "acc": True
+            "f1": mean,
+            "em": mean,
         }
\ No newline at end of file
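Downstream, the harness reduces the per-document dicts returned by process_results using the functions from aggregation(). A minimal sketch of that reduction, assuming mean simply averages a list of floats (the sample scores are made up):

    scores = [{"em": 1.0, "f1": 1.0}, {"em": 0.0, "f1": 0.5}]
    mean = lambda xs: sum(xs) / len(xs)  # stand-in for lm_eval.base.mean
    aggregated = {key: mean([s[key] for s in scores]) for key in ("em", "f1")}
    print(aggregated)  # {'em': 0.5, 'f1': 0.75}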