gaoqiong / lm-evaluation-harness · Commits

Commit b9b3159b, authored Feb 08, 2021 by thefazzer
Parent: 602d3e20

Bugfixes, answer mapping, comments

Showing 1 changed file with 59 additions and 31 deletions:

lm_eval/tasks/coqa.py (+59 / -31)
+# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
 import json
 import random
 import numpy as np
 from lm_eval.base import Task, rf, mean
 from ..utils import sh
 from itertools import zip_longest
 import transformers.data.metrics.squad_metrics as squad_metrics
+import collections
+import datasets
+import numpy as np
+from lm_eval.base import rf, mean
+from .common import HFTask
+from tqdm import tqdm
+import string, re


 class CoQA(Task):

     def download(self):
         pass
         # -N only overwrites if the remote file has changed
...
...
@@ -28,10 +34,14 @@ class CoQA(Task):
         return False

     def training_docs(self):
-        return json.load(open('data/coqa/coqa-train-v1.0.json'))['data']
+        doc_data = json.load(open('data/coqa/coqa-train-v1.0.json'))['data']
+        for doc in doc_data:
+            for answer in doc['answers']:
+                answer['input_text'] = self.get_answer_choice(answer['input_text'])
+        return doc_data

     def validation_docs(self):
         return json.load(open('data/coqa/coqa-dev-v1.0.json'))['data']

     def test_docs(self):
         pass
...
...
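The training_docs change above no longer returns the raw CoQA answers: every gold answer is rewritten in place to the answer-category label produced by get_answer_choice (added further down in this diff). A minimal sketch of the effect on a hypothetical toy document (the real loader reads data/coqa/coqa-train-v1.0.json):

    import transformers.data.metrics.squad_metrics as squad_metrics

    def get_answer_choice(raw_text):
        # same mapping as the method added in this commit
        if raw_text == "unknown":
            return '0'
        if squad_metrics.normalize_answer(raw_text) == "yes":
            return '1'
        if squad_metrics.normalize_answer(raw_text) == "no":
            return '2'
        return '3'

    # hypothetical toy document
    doc = {"answers": [{"input_text": "Yes."}, {"input_text": "a white cat"}]}
    for answer in doc["answers"]:
        answer["input_text"] = get_answer_choice(answer["input_text"])
    print(doc["answers"])  # [{'input_text': '1'}, {'input_text': '3'}]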
@@ -40,44 +50,55 @@ class CoQA(Task):
         return "Given a passage and a conversation so far, answer the next question in the conversation."

     def doc_to_text(self, doc):
         # Each "doc" is a story and conversation (Q and A pairs).
+        # Given a passage p, the conversation history {q1, a1, ..., q_i-1, a_i-1}
+        # and a question q_i, the task is to predict the answer a_i
         doc_text = doc["story"] + '\n\n'
-        for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]):  # omit target answer
+        for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]):  # omit target answer a_i
             question = f"Q: {q['input_text']}" + '\n\n'
             answer = f"A: {a['input_text']}" + '\n\n' if a is not None else "A: "
             doc_text += question + answer
+        print(doc_text)
         return doc_text

     @classmethod
     def get_answers(cls, doc, turn_id):
-        # This function returns an answer and valid alternatives.
+        # Returns unique answers and valid alternatives
+        # (some questions in CoQA have multiple valid answers).
         answers = []
         answer_forturn = doc["answers"][turn_id - 1]["input_text"]
         answers.append(answer_forturn)

-        additionals = doc.get("additional_answers")
-        if additionals:
-            for key in additionals:
-                additional_answer_for_turn = additionals[key][turn_id - 1]["input_text"]
-                if additional_answer_for_turn.upper() not in map(str.upper, answers):
+        additional_answers = doc.get("additional_answers")
+        if additional_answers:
+            for key in additional_answers:
+                additional_answer_for_turn = additional_answers[key][turn_id - 1]["input_text"]
+                if additional_answer_for_turn.lower() not in map(str.lower, answers):
                     answers.append(additional_answer_for_turn)
         return answers

-    def doc_to_target(self, doc, turnid=None):
-        # Default to predicting the last turn.
-        if turnid is None:
-            turnid = len(doc["questions"])
-        all_answers = self.get_answers(doc, turnid)
-        return all_answers[0]  # ignore alternative answers for now
-
+    @classmethod
+    def get_answer_choice(self, raw_text):
+        # Function maps answers to CoQA answer categories:
+        # ~ 1/5 of the CoQA answers are Yes/No,
+        # ~ 2/3 of the CoQA answers are span-based
+        # (answers overlap with the passage, ignoring punctuation and case mismatches)
+        if raw_text == "unknown":
+            return '0'
+        if squad_metrics.normalize_answer(raw_text) == "yes":
+            return '1'
+        if squad_metrics.normalize_answer(raw_text) == "no":
+            return '2'
+        return '3'  # Not a yes/no question

     @staticmethod
     def compute_scores(gold_list, pred):
         # tests for an exact match on the normalised answer (compute_exact)
         # and for token overlap (compute_f1)
         f1_sum = 0.0
         em_sum = 0.0
         if len(gold_list) > 1:
             for i in range(len(gold_list)):
                 gold_answers = gold_list[0:i] + gold_list[i + 1:]
                 # the prediction is compared against the n remaining golds and the maximum is taken
                 em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
                 f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
         else:
...
...
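When a question has several gold answers, compute_scores above scores the prediction leave-one-out: each gold answer is dropped in turn, the prediction is compared against the remaining golds, and the maximum is kept, so a single eccentric reference cannot drag the score down. A standalone sketch with made-up answers (the else branch for a single gold answer is collapsed in this diff, so it is omitted here too):

    import transformers.data.metrics.squad_metrics as squad_metrics

    gold_list = ["a white cat", "the white cat", "white"]  # hypothetical golds
    pred = "a white cat"

    em_sum, f1_sum = 0.0, 0.0
    for i in range(len(gold_list)):
        gold_answers = gold_list[0:i] + gold_list[i + 1:]  # leave gold i out
        em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
        f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)

    print({'em': em_sum / len(gold_list), 'f1': f1_sum / len(gold_list)})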
@@ -86,6 +107,14 @@ class CoQA(Task):
         return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))}

+    def doc_to_target(self, doc, turnid=None):
+        # Default to predicting the last turn.
+        if turnid is None:
+            turnid = len(doc["questions"])
+        raw_text = doc['answers'][turnid - 1]["input_text"]
+        return self.get_answer_choice(raw_text)
+
     def construct_requests(self, doc, ctx):
         """ Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.
...
...
@@ -97,12 +126,12 @@ class CoQA(Task):
         language description, as well as the few shot examples, and the question
         part of the document for `doc`.
         """
-        requests = []
-        for answers in self.get_answers(doc, len(doc["questions"])):
-            for a in answers:
-                requests.append(rf.loglikelihood(ctx, " " + a))
-        return requests
+        ll_requests = [
+            rf.loglikelihood(ctx, " " + i)
+            for i in ['0', '1', '2', '3']
+        ]
+        return ll_requests

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
         dict where keys are the names of submetrics and values are the values of
...
...
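With the rewrite above, construct_requests no longer asks the LM to score free-text answer strings; it requests the loglikelihood of each of the four category labels as a continuation of the context, and process_results (next hunk) takes the argmax. A rough sketch of the resulting requests, with a hypothetical context string (rf.loglikelihood is the harness request factory already imported in this file):

    from lm_eval.base import rf

    ctx = "The cat sat on the mat.\n\nQ: Was the cat on the mat?\n\nA: "  # made-up context
    # one request per CoQA answer category: '0' unknown, '1' yes, '2' no, '3' span-based
    ll_requests = [
        rf.loglikelihood(ctx, " " + i)
        for i in ['0', '1', '2', '3']
    ]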
@@ -113,16 +142,15 @@ class CoQA(Task):
         :param results:
             The results of the requests created in construct_requests.
         """
         turn_id = len(doc["questions"])
-        gold_list = self.get_answers(doc, turn_id)
-        pred = np.argmax(results)
+        gold_list = [self.get_answer_choice(r_text) for r_text in self.get_answers(doc, turn_id)]
+        pred = str(np.argmax(results))

-        (em, f1) = self.compute_scores(gold_list, pred)
+        scores = self.compute_scores(gold_list, pred)

         return {
-            "f1": f1,
-            "em": em,
+            "f1": scores['f1'],
+            "em": scores['em'],
         }

     def higher_is_better(self):
...
...
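End to end, process_results now works in category space on both sides: gold answers are mapped through get_answer_choice, and the prediction is the argmax over the four loglikelihoods, stringified to match. A minimal sketch with invented numbers (note that on one-token labels like these, compute_f1 effectively reduces to exact match):

    import numpy as np
    import transformers.data.metrics.squad_metrics as squad_metrics

    results = [-4.2, -0.3, -2.9, -1.1]  # hypothetical loglikelihoods for '0'..'3'
    gold_list = ['1', '1']              # gold answers already mapped to categories
    pred = str(np.argmax(results))      # -> '1'

    # the prediction matches the gold category, so both metrics are perfect
    print(squad_metrics.compute_exact(gold_list[0], pred))  # 1
    print(squad_metrics.compute_f1(gold_list[0], pred))     # 1.0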