gaoqiong / lm-evaluation-harness

Unverified commit c155698f
Authored Feb 14, 2021 by Leo Gao; committed by GitHub on Feb 14, 2021

Merge pull request #133 from EleutherAI/fazz/refactor-task-coqa
CoQa Task Initial Implementation
Parents: 1e5194a2, 431e53de
Showing 4 changed files with 111 additions and 43 deletions
lm_eval/models/gpt2.py       +1    -1
lm_eval/tasks/__init__.py    +2    -1
lm_eval/tasks/coqa.py        +105  -41
tests/test_tasks.py          +3    -0
lm_eval/models/gpt2.py

@@ -60,7 +60,7 @@ class GPT2LM(LM):
         for context, until in tqdm(requests):
             if isinstance(until, str): until = [until]
 
-            context_enc = torch.tensor([self.tokenizer.encode(context)]).to(self.device)
+            context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - 1024:]]).to(self.device)
 
             primary_until, = self.tokenizer.encode(until[0])
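The only change here is the slice on the encoded context: keeping the last 1024 - MAX_GEN_TOKS tokens leaves room for the generated continuation inside GPT-2's 1024-token window. A minimal sketch of the arithmetic, assuming MAX_GEN_TOKS = 256 (an illustrative value; the real constant is defined on GPT2LM):

    # Sketch only: MAX_GEN_TOKS = 256 is assumed here for illustration.
    MAX_GEN_TOKS = 256
    MAX_LENGTH = 1024                            # GPT-2 context window

    tokens = list(range(3000))                   # stands in for tokenizer.encode(context)
    kept = tokens[MAX_GEN_TOKS - MAX_LENGTH:]    # i.e. tokens[-768:], the most recent context

    assert len(kept) == MAX_LENGTH - MAX_GEN_TOKS
    # 768 prompt tokens plus up to 256 generated tokens fit in the 1024-token window.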
lm_eval/tasks/__init__.py

@@ -5,6 +5,7 @@ import sacrebleu
 from . import superglue
 from . import glue
 from . import arc
+from . import coqa
 from . import race
 from . import webqs
 from . import anli

@@ -81,7 +82,7 @@ TASK_REGISTRY = {
     "wsc": superglue.SGWinogradSchemaChallenge,
 
     # Order by benchmark/genre?
+    "coqa": coqa.CoQA,
     "lambada": lambada.LAMBADA,
     "piqa": piqa.PiQA,
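With the two added lines, "coqa" resolves through TASK_REGISTRY like any other task name. A small usage sketch (note that instantiating the task runs its download() from __init__, per coqa.py below):

    from lm_eval import tasks

    task_class = tasks.TASK_REGISTRY["coqa"]   # -> lm_eval.tasks.coqa.CoQA
    task = task_class()                        # downloads the CoQA JSON files on first use
    print(task.has_training_docs())            # True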
lm_eval/tasks/coqa.py

-# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
+import os
 import json
-import random
-from lm_eval.base import Task
+from lm_eval.base import Task, rf, mean
 from ..utils import sh
+from itertools import zip_longest
+import transformers.data.metrics.squad_metrics as squad_metrics
+import collections
+import datasets
+import numpy as np
+from lm_eval.base import rf, mean
+from .common import HFTask
+from tqdm import tqdm
+import string, re
 class CoQA(Task):
+    def __init__(self):
+        self.download()
+
     def download(self):
-        #TODO: don't download if files already there
-        sh("""
-        mkdir -p data/coqa
-        wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O data/coqa/coqa-train-v1.0.json
-        wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json -O data/coqa/coqa-dev-v1.0.json
-        """)
+        coqa_train_filepath = 'data/coqa/coqa-train-v1.0.json'
+        coqa_dev_filepath = 'data/coqa/coqa-dev-v1.0.json'
+
+        sh("""mkdir -p data/coqa""")
+
+        if not os.path.exists(coqa_train_filepath):
+            sh("""wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O """ + coqa_train_filepath)
+        if not os.path.exists(coqa_dev_filepath):
+            sh("""wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json -O """ + coqa_dev_filepath)
 
     def has_training_docs(self):
         return True
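The rewritten download() is now idempotent: each JSON file is fetched only if it is not already on disk. The same guard pattern in isolation, using urllib instead of the harness's sh()/wget helper (a sketch, not the code above):

    import os
    import urllib.request

    def download_once(url, path):
        # Skip the network call when the file is already present.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        if not os.path.exists(path):
            urllib.request.urlretrieve(url, path)

    download_once(
        "http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json",
        "data/coqa/coqa-dev-v1.0.json",
    )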
@@ -36,16 +43,71 @@ class CoQA(Task):
         pass
 
     def fewshot_description(self):
-        # TODO: figure out description
-        return ""
+        return "Given a passage and a conversation so far, answer the next question in the conversation."
     def doc_to_text(self, doc):
-        # TODO: implement.
-        raise NotImplementedError('doc_to_text not implemented')
-
-    def doc_to_target(self, doc):
-        # TODO: implement.
-        raise NotImplementedError('doc_to_target not implemented')
+        # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
+        # and a question qi, the task is to predict the answer ai
+        doc_text = doc["story"] + '\n\n'
+        for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]):  # omit target answer ai
+            question = f"Q: {q['input_text']}" + '\n\n'
+            answer = f"A: {a['input_text']}" + '\n\n' if a is not None else "A:"
+            doc_text += question + answer
+        return doc_text
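doc_to_text() renders the passage followed by every earlier question/answer turn and ends with a bare "A:" for the turn to be predicted (zip_longest pads the final question with a = None because the last answer is held out). With an invented document in the CoQA field layout used above:

    # Toy document; the field names match the CoQA JSON, the values are made up.
    doc = {
        "story": "Alice has a red bicycle. She rides it to work every day.",
        "questions": [{"input_text": "What color is the bicycle?"},
                      {"input_text": "Who owns it?"}],
        "answers": [{"input_text": "red"},
                    {"input_text": "Alice"}],
    }
    # CoQA().doc_to_text(doc) (instantiation triggers the dataset download) returns:
    #
    #   Alice has a red bicycle. She rides it to work every day.
    #
    #   Q: What color is the bicycle?
    #
    #   A: red
    #
    #   Q: Who owns it?
    #
    #   A: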
+    @classmethod
+    def get_answers(cls, doc, turn_id):
+        # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
+        answers = []
+        answer_forturn = doc["answers"][turn_id - 1]["input_text"]
+        answers.append(answer_forturn)
+
+        additional_answers = doc.get("additional_answers")
+        if additional_answers:
+            for key in additional_answers:
+                additional_answer_for_turn = additional_answers[key][turn_id - 1]["input_text"]
+                if additional_answer_for_turn.lower() not in map(str.lower, answers):
+                    answers.append(additional_answer_for_turn)
+        return answers
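get_answers() gathers the reference answer for a turn plus any alternatives from additional_answers, dropping case-insensitive duplicates. An illustrative call (the additional_answers keys and strings below are invented):

    from lm_eval.tasks.coqa import CoQA

    doc = {
        "answers": [{"input_text": "red"}],
        "additional_answers": {
            "0": [{"input_text": "Red"}],        # duplicate ignoring case, skipped
            "1": [{"input_text": "a red one"}],  # distinct alternative, kept
        },
    }
    print(CoQA.get_answers(doc, turn_id=1))      # ['red', 'a red one']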
+    @classmethod
+    def get_answer_choice(self, raw_text):
+        # Function maps answers to CoQA answer categories
+        # ~ 1/5 of the CoQA answers are Yes/No
+        # ~ 2/3 of the CoQA answers are span-based
+        # (answers overlap with the passage ignoring punctuation and case mismatch)
+        if raw_text == "unknown":
+            return '0'
+        if squad_metrics.normalize_answer(raw_text) == "yes":
+            return '1'
+        if squad_metrics.normalize_answer(raw_text) == "no":
+            return '2'
+        return '3'  # Not a yes/no question
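get_answer_choice() buckets a raw answer string into CoQA's answer categories ('0' unknown, '1' yes, '2' no, '3' free-form/span). A quick check of the mapping, assuming the harness's transformers dependency is installed:

    from lm_eval.tasks.coqa import CoQA

    assert CoQA.get_answer_choice("unknown") == '0'
    assert CoQA.get_answer_choice("Yes.") == '1'   # normalize_answer() strips case and punctuation
    assert CoQA.get_answer_choice("No") == '2'
    assert CoQA.get_answer_choice("a red bicycle") == '3'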
+    @staticmethod
+    def compute_scores(gold_list, pred):
+        # tests for exact match and on the normalised answer (compute_exact)
+        # test for overlap (compute_f1)
+        f1_sum = 0.0
+        em_sum = 0.0
+        if len(gold_list) > 1:
+            for i in range(len(gold_list)):
+                gold_answers = gold_list[0:i] + gold_list[i + 1:]
+                # predictions compared against (n) golds and take maximum
+                em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
+                f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
+        else:
+            em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
+            f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)
+
+        return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))}
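compute_scores() follows the CoQA convention for multiple references: each gold answer is held out in turn, the prediction is scored against the remaining ones, and the EM/F1 sums are averaged over the number of golds. A worked example (requires transformers for squad_metrics):

    from lm_eval.tasks.coqa import CoQA

    print(CoQA.compute_scores(["red", "a red one"], pred="red"))
    # i=0 compares "red" against ["a red one"] (em 0, partial f1),
    # i=1 compares "red" against ["red"] (em 1, f1 1),
    # so em = 0.5 and f1 is the average of the two per-gold maxima.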
+    def doc_to_target(self, doc, turnid=None):
+        # Default to prediction of last turn.
+        if turnid is None:
+            turnid = len(doc["questions"])
+        raw_text = doc['answers'][turnid - 1]["input_text"]
+        return " " + raw_text
     def construct_requests(self, doc, ctx):
         """ Uses RequestFactory to construct Requests and returns an iterable of

@@ -58,8 +120,8 @@ class CoQA(Task):
         language description, as well as the few shot examples, and the question
         part of the document for `doc`.
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        cont_request = rf.greedy_until(ctx, ['\n'])
+        return cont_request
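construct_requests() now returns a single rf.greedy_until request: decode greedily from the prompt built by doc_to_text() and the fewshot context until a newline. A sketch of what gets handed back (the evaluator, not shown here, runs it through the model and passes the generated string to process_results as results[0]):

    from lm_eval.base import rf

    req = rf.greedy_until("passage...\n\nQ: Who owns it?\n\nA:", ['\n'])
    # `req` is one lm_eval.base.Request recording the request type and its
    # arguments; note it is a single object rather than a list, which is what
    # the tests/test_tasks.py change below accounts for.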
     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a

@@ -71,23 +133,25 @@ class CoQA(Task):
         :param results:
             The results of the requests created in construct_requests.
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        turn_id = len(doc["questions"])
+        gold_list = self.get_answers(doc, turn_id)
+        pred = results[0]
+
+        scores = self.compute_scores(gold_list, pred)
+
+        return {
+            "f1": scores['f1'],
+            "em": scores['em'],
+        }
 
-    def aggregation(self):
-        """
-        :returns: {str: [float] -> float}
-            A dictionary where keys are the names of submetrics and values are
-            functions that aggregate a list of metrics
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
-
     def higher_is_better(self):
-        """
-        :returns: {str: bool}
-            A dictionary where keys are the names of submetrics and values are
-            whether a higher value of the submetric is better
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "f1": True,
+            "em": True,
+        }
+
+    def aggregation(self):
+        return {
+            "f1": mean,
+            "em": mean,
+        }
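process_results() emits per-document {"f1", "em"} scores, higher_is_better() marks both as ascending, and aggregation() reduces the per-document lists with mean. A simplified sketch of how those pieces combine (the real loop lives in the evaluator):

    from lm_eval.base import mean

    per_doc = [{"f1": 0.8, "em": 0.0}, {"f1": 1.0, "em": 1.0}]   # invented process_results outputs
    agg = {"f1": mean, "em": mean}                               # what aggregation() returns
    print({k: fn([d[k] for d in per_doc]) for k, fn in agg.items()})
    # {'f1': 0.9, 'em': 0.5}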
tests/test_tasks.py

@@ -76,6 +76,9 @@ def test_documents_and_requests(taskname, Task):
         reqs = task.construct_requests(doc, txt)
+        # construct_requests can return just one request
+        if not isinstance(reqs, (list, tuple)):
+            reqs = [reqs]
 
         # todo: mock lm after refactoring evaluator.py to not be a mess
         for req in reqs:
             assert isinstance(req, base.Request)
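The three added lines let a task hand back either a bare Request (as CoQA now does) or a list of them. The same normalization on its own:

    def as_request_list(reqs):
        # Accept a single request or an iterable of requests.
        if not isinstance(reqs, (list, tuple)):
            reqs = [reqs]
        return list(reqs)

    assert as_request_list("one") == ["one"]
    assert as_request_list(["a", "b"]) == ["a", "b"]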