Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
2cfdd80a
Commit
2cfdd80a
authored
Feb 13, 2021
by
Jason Phang
Browse files
ReCoRD fixup
parent
e4e9228e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
17 deletions
+17
-17
lm_eval/tasks/superglue.py
lm_eval/tasks/superglue.py
+17
-17
No files found.
lm_eval/tasks/superglue.py
View file @
2cfdd80a
...
...
@@ -272,26 +272,25 @@ class ReCoRD(HFTask):
def
training_docs
(
self
):
# In ReCoRD, each doc manifests multiple "examples" in the context of few shot example packing.
# Each doc consists of multiple answer candidates, each of which is scored yes/no.
# Hence, we create one "doc" for each (context + passage, answer) pair.
# Moreover, we only use the correct answers for context packing
# (This is not an issue for evaluation, where we can directly score multiple candidates at once).
if
self
.
_training_docs
is
None
:
self
.
_training_docs
=
[]
for
doc
in
self
.
data
[
"train"
]:
for
entity
in
list
(
set
(
doc
[
"entities"
])):
self
.
_training_docs
.
append
({
"passage"
:
doc
[
"passage"
],
"query"
:
doc
[
"query"
],
"entity"
:
entity
,
"label"
:
entity
in
doc
[
"answers"
],
})
self
.
_training_docs
.
append
(
self
.
_process_doc
(
doc
))
return
self
.
_training_docs
def
validation_docs
(
self
):
# Following from .training_docs, for validation_docs, each document corresponds to
# the original doc from the dataset, i.e. comprises lists of entities, and which
# entities are correct (potentially multiple)
yield
from
self
.
data
[
"validation"
]
# See: training_docs
for
doc
in
self
.
data
[
"validation"
]:
yield
self
.
_process_doc
(
doc
)
@
classmethod
def
_process_doc
(
cls
,
doc
):
return
{
"passage"
:
doc
[
"passage"
],
"query"
:
doc
[
"query"
],
"entities"
:
sorted
(
list
(
set
(
doc
[
"entities"
]))),
"answers"
:
sorted
(
list
(
set
(
doc
[
"answers"
]))),
}
def
doc_to_text
(
self
,
doc
):
initial_text
,
*
highlights
=
doc
[
"passage"
].
strip
().
split
(
"
\n
@highlight
\n
"
)
...
...
@@ -305,7 +304,8 @@ class ReCoRD(HFTask):
return
f
' -
{
query
}
'
.
replace
(
"@placeholder"
,
entity
)
def
doc_to_target
(
self
,
doc
):
return
self
.
format_answer
(
query
=
doc
[
"query"
],
entity
=
doc
[
"entity"
])
# We only output the first correct entity in a doc
return
self
.
format_answer
(
query
=
doc
[
"query"
],
entity
=
doc
[
"answers"
][
0
])
def
construct_requests
(
self
,
doc
,
ctx
):
requests
=
[
...
...
@@ -319,10 +319,10 @@ class ReCoRD(HFTask):
# - Pick the maximum likelihood prediction entity
# - Evaluate the accuracy and token F1 PER EXAMPLE
# - Average over all examples
max_idx
=
np
.
argmax
(
np
.
array
(
results
))
max_idx
=
np
.
argmax
(
np
.
array
(
[
result
[
0
]
for
result
in
results
]
))
prediction
=
doc
[
"entities"
][
max_idx
]
gold_label_set
=
list
(
set
(
doc
[
"answers"
]
))
gold_label_set
=
doc
[
"answers"
]
f1
=
metric_max_over_ground_truths
(
squad_metrics
.
compute_f1
,
prediction
,
gold_label_set
)
em
=
metric_max_over_ground_truths
(
squad_metrics
.
compute_exact
,
prediction
,
gold_label_set
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment