Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
c0862026
Commit
c0862026
authored
Jan 31, 2021
by
thefazzer
Browse files
Text & target impl, support fns, refactoring
parent
6738b241
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
48 additions
and
30 deletions
+48
-30
lm_eval/tasks/coqa.py
lm_eval/tasks/coqa.py
+48
-30
No files found.
lm_eval/tasks/coqa.py
View file @
c0862026
...
@@ -2,25 +2,21 @@
...
@@ -2,25 +2,21 @@
import
json
import
json
import
random
import
random
from
lm_eval.base
import
Dataset
import
numpy
as
np
from
lm_eval.base
import
Dataset
,
rf
,
mean
from
..utils
import
sh
from
..utils
import
sh
import
itertools
from
itertools
import
zip_longest
class CoQA(Dataset):
    """CoQA conversational question-answering task.

    Downloads the CoQA train/dev JSON files into ``data/coqa`` on
    construction; the rest of the task interface is defined by the
    methods below.
    """

    def __init__(self):
        self.download()

    def download(self):
        # -N only overwrites if the remote file has changed.
        # NOTE(review): wget ignores -N when -O is also given, so these
        # commands re-download unconditionally — confirm, or switch to
        # `wget -N -P data/coqa` without -O.
        sh("""
            mkdir -p data/coqa
            wget -N http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O data/coqa/coqa-train-v1.0.json
            wget -N http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json -O data/coqa/coqa-dev-v1.0.json
            """)
def has_training_docs(self):
    """Return True: CoQA ships a labelled training split (coqa-train-v1.0.json)."""
    return True
...
@@ -43,12 +39,34 @@ class CoQA(Dataset):
...
@@ -43,12 +39,34 @@ class CoQA(Dataset):
return
"Given a passage and a conversation so far, answer the next question in the conversation."
return
"Given a passage and a conversation so far, answer the next question in the conversation."
def doc_to_text(self, doc):
    """Render a CoQA document as a prompt.

    Produces the story, a blank line, then alternating "Q:" / "A:" turns.
    The answer of the final turn is omitted so the model must supply it.

    :param doc: a CoQA document with "story", "questions" and "answers" keys
    :returns: str prompt text
    """
    doc_text = doc["story"] + '\n\n'
    # answers[:-1] drops the target answer; zip_longest therefore pads the
    # final question with a = None so it still gets an empty "A:" prompt.
    for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]):
        question = f"Q: {q['input_text']}" + '\n\n'
        # NOTE(review): the prompt ends with "A:\n\n", so the scored
        # continuation lands after a blank line — confirm the trailing
        # "\n\n" (rather than a bare "A:") is intended.
        answer = f"A: {a['input_text']}" + '\n\n' if a is not None else "A:\n\n"
        doc_text += question + answer
    return doc_text
@classmethod
def get_answers(cls, doc, turn_id):
    """Collect the gold answer and valid alternative answers for one turn.

    :param doc: a CoQA document with "answers" (and optionally
        "additional_answers") keys
    :param turn_id: 1-based index of the conversation turn
    :returns: list[str] — the gold answer first, then each alternative
        annotator answer that differs case-insensitively from those
        already collected
    """
    # TODO: all distinct answers taking into account whitespace?
    answers = []
    answer_forturn = doc["answers"][turn_id - 1]["input_text"]
    answers.append(answer_forturn)

    additionals = doc.get("additional_answers")
    if additionals:
        for key in additionals:
            additional_answer_for_turn = additionals[key][turn_id - 1]["input_text"]
            # De-duplicate case-insensitively against what we already have.
            if additional_answer_for_turn.upper() not in map(str.upper, answers):
                answers.append(additional_answer_for_turn)
    return answers
def doc_to_target(self, doc, turnid=None):
    """Return the gold answer string for conversation turn `turnid`.

    :param doc: a CoQA document with a "questions" key
    :param turnid: 1-based turn index; defaults to the last turn
    :returns: str — the gold answer (alternatives from get_answers are
        ignored for now)
    """
    # Default to predicting the last turn.
    if turnid is None:
        turnid = len(doc["questions"])
    all_answers = self.get_answers(doc, turnid)
    return all_answers[0]  # ignore alternative answers for now
def construct_requests(self, doc, ctx):
    """ Uses RequestFactory to construct Requests and returns an iterable of
    Requests which will be sent to the LM.

    :param doc: the document whose final turn is being predicted
    :param ctx: str
        The context string, generated by fewshot_context. This includes the natural
        language description, as well as the few shot examples, and the question
        part of the document for `doc`.
    """
    # One loglikelihood request per acceptable answer for the final turn
    # (gold first, then alternatives), each scored as a continuation of ctx.
    requests = []
    for answer in self.get_answers(doc, len(doc["questions"])):
        requests.append(rf.loglikelihood(ctx, " " + answer))
    return requests
def process_results(self, doc, results):
    """Take a single document and the LM results and evaluates, returning a
    dict where keys are the names of submetrics and values are the values of
    the metric for that one document.

    :param doc: the document
    :param results: the results of the requests created in construct_requests,
        one loglikelihood per candidate answer (gold first, then alternatives)
    """
    gold = self.get_answers(doc, len(doc["questions"]))
    # Index of the highest-loglikelihood candidate answer.
    pred = int(np.argmax(results))
    return {
        # The gold answer is always at index 0 of get_answers' output, so the
        # prediction is correct exactly when the gold candidate scores highest.
        # (The committed `int(pred == gold)` compared the int index against
        # the answer *list* — always falsy for one candidate and a numpy
        # broadcast error for several.)
        "acc": int(pred == 0),
        # "f1": (gold, pred),  # TODO: Fix
    }
def aggregation(self):
    """
    :returns: {str: [float] -> float}
        A dictionary where keys are the names of submetrics and values are
        functions that aggregate a list of metrics
    """
    # Per-document acc values are averaged over the evaluation set.
    return {
        "acc": mean
    }
def higher_is_better(self):
    """
    :returns: {str: bool}
        A dictionary where keys are the names of submetrics and values are
        whether a higher value of the submetric is better
    """
    return {
        "acc": True
    }
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment