Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
60a6fd8c
Commit
60a6fd8c
authored
Jan 27, 2021
by
Leo Gao
Browse files
Implement unit testing and fix lots of problems with tasks
parent
693c19e2
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
103 additions
and
66 deletions
+103
-66
lm_eval/models/dummy.py
lm_eval/models/dummy.py
+18
-6
lm_eval/models/gpt2.py
lm_eval/models/gpt2.py
+0
-1
lm_eval/models/gpt3.py
lm_eval/models/gpt3.py
+2
-0
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+19
-17
lm_eval/tasks/arithmetic.py
lm_eval/tasks/arithmetic.py
+1
-1
lm_eval/tasks/common.py
lm_eval/tasks/common.py
+2
-0
lm_eval/tasks/drop.py
lm_eval/tasks/drop.py
+1
-1
lm_eval/tasks/race.py
lm_eval/tasks/race.py
+6
-9
lm_eval/tasks/sat.py
lm_eval/tasks/sat.py
+2
-0
lm_eval/tasks/storycloze.py
lm_eval/tasks/storycloze.py
+3
-3
lm_eval/tasks/superglue.py
lm_eval/tasks/superglue.py
+34
-13
lm_eval/tasks/triviaqa.py
lm_eval/tasks/triviaqa.py
+8
-10
lm_eval/tasks/wsc273.py
lm_eval/tasks/wsc273.py
+7
-5
No files found.
lm_eval/models/dummy.py
View file @
60a6fd8c
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
import random

from lm_eval.base import LM
from . import MODEL_REGISTRY


@MODEL_REGISTRY.register("dummy")
class DummyLM(LM):
    """A no-op language model returning random scores.

    Useful as a stand-in for unit-testing task code without loading a real
    model: every loglikelihood request gets an independent value in (-1, 0).
    """

    def __init__(self):
        # No state needed: the dummy model has no weights or tokenizer.
        pass

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Construct the model; the CLI argument string is ignored."""
        return cls()

    def loglikelihood(self, requests):
        """Return one (logprob, is_greedy) pair per request.

        logprob is a random value in (-1, 0); is_greedy is always False,
        since a dummy model never claims the continuation is greedy-optimal.
        """
        res = []
        for _ in requests:
            res.append((-random.random(), False))
        return res

    def greedy_until(self, requests):
        # TODO: implement
        pass
lm_eval/models/gpt2.py
View file @
60a6fd8c
...
...
@@ -19,7 +19,6 @@ class GPT2LM(LM):
return
cls
(
device
=
args
.
get
(
"device"
,
"cpu"
))
def
loglikelihood
(
self
,
requests
):
print
(
requests
)
res
=
[]
# TODO: vectorize properly
for
context
,
continuation
in
tqdm
(
requests
):
...
...
lm_eval/models/gpt3.py
View file @
60a6fd8c
...
...
@@ -32,6 +32,8 @@ class GPT3LM(LM):
return
cls
(
engine
=
args
.
get
(
"engine"
,
"davinci"
))
def
loglikelihood
(
self
,
context
,
continuation
):
# TODO: implement new framework
import
openai
context_enc
=
self
.
tokenizer
.
encode
(
context
)
...
...
lm_eval/tasks/__init__.py
View file @
60a6fd8c
...
...
@@ -23,7 +23,7 @@ TASK_REGISTRY = {
"rte"
:
glue
.
RTE
,
"qnli"
:
glue
.
QNLI
,
"qqp"
:
glue
.
QQP
,
"stsb"
:
glue
.
STSB
,
#
"stsb": glue.STSB,
# not implemented yet
"sst"
:
glue
.
SST
,
"wnli"
:
glue
.
WNLI
,
# SuperGLUE
...
...
@@ -33,23 +33,25 @@ TASK_REGISTRY = {
"multirc"
:
superglue
.
MultiRC
,
"record"
:
superglue
.
ReCoRD
,
"wic"
:
superglue
.
WordsInContext
,
"wsc"
:
superglue
.
SGWinogradSchemaChallenge
,
#"wsc": superglue.SGWinogradSchemaChallenge, # not implemented yet
# Order by benchmark/genre?
"arc_easy"
:
arc
.
ARCEasy
,
"arc_challenge"
:
arc
.
ARCChallenge
,
"quac"
:
quac
.
QuAC
,
"hellaswag"
:
hellaswag
.
HellaSwag
,
"openbookqa"
:
openbookqa
.
OpenBookQA
,
"sat"
:
sat
.
SATAnalogies
,
"squad"
:
squad
.
SQuAD
,
"race"
:
race
.
RACE
,
"naturalqs"
:
naturalqs
.
NaturalQs
,
"webqs"
:
webqs
.
WebQs
,
"wsc273"
:
wsc273
.
WinogradSchemaChallenge273
,
"winogrande"
:
winogrande
.
Winogrande
,
"anli_r1"
:
anli
.
ANLIRound1
,
"anli_r2"
:
anli
.
ANLIRound2
,
"anli_r3"
:
anli
.
ANLIRound3
,
# "arc_easy": arc.ARCEasy, # not implemented yet
# "arc_challenge": arc.ARCChallenge, # not implemented yet
# "quac": quac.QuAC, # not implemented yet
# "hellaswag": hellaswag.HellaSwag, # not implemented yet
# "openbookqa": openbookqa.OpenBookQA, # not implemented yet
# "sat": sat.SATAnalogies, # not implemented yet
# "squad": squad.SQuAD, # not implemented yet
# "race": race.RACE, # not implemented yet
# "naturalqs": naturalqs.NaturalQs, # not implemented yet
# "webqs": webqs.WebQs, # not implemented yet
# "wsc273": wsc273.WinogradSchemaChallenge273, # not implemented yet
# "winogrande": winogrande.Winogrande, # not implemented yet
# "anli_r1": anli.ANLIRound1, # not implemented yet
# "anli_r2": anli.ANLIRound2, # not implemented yet
# "anli_r3": anli.ANLIRound3, # not implemented yet
# arithmetic
"arithmetic_2da"
:
arithmetic
.
Arithmetic2DPlus
,
"arithmetic_2ds"
:
arithmetic
.
Arithmetic2DMinus
,
...
...
lm_eval/tasks/arithmetic.py
View file @
60a6fd8c
...
...
@@ -12,7 +12,6 @@ class Arithmetic(Dataset):
def
__init__
(
self
):
super
().
__init__
()
self
.
set_docs
()
def
download
(
self
):
file_name
,
checksum
=
self
.
get_file_download_info
()
...
...
@@ -20,6 +19,7 @@ class Arithmetic(Dataset):
if
not
os
.
path
.
exists
(
self
.
directory
):
os
.
makedirs
(
self
.
directory
)
download_file
(
url
,
self
.
directory
+
file_name
,
checksum
)
self
.
set_docs
()
@
abc
.
abstractmethod
def
get_file_download_info
(
self
):
...
...
lm_eval/tasks/common.py
View file @
60a6fd8c
...
...
@@ -11,6 +11,8 @@ class HFTask(Dataset):
def
__init__
(
self
):
super
().
__init__
()
self
.
_training_docs
=
None
def
download
(
self
):
self
.
data
=
datasets
.
load_dataset
(
path
=
self
.
DATASET_PATH
,
name
=
self
.
DATASET_NAME
)
def
has_training_docs
(
self
):
...
...
lm_eval/tasks/drop.py
View file @
60a6fd8c
...
...
@@ -11,7 +11,7 @@ class DROP(Dataset):
DATAFOLDER
=
Path
(
__file__
).
parent
/
"../../data/drop"
def
__init__
(
self
):
s
elf
.
download
()
s
uper
().
__init__
()
def
has_training_docs
(
self
):
"""Whether the task has a training set"""
...
...
lm_eval/tasks/race.py
View file @
60a6fd8c
...
...
@@ -54,16 +54,13 @@ class RACE(HFTask):
# TODO: figure out description
return
""
def
doc_to_text
(
self
,
doc
,
include_target
=
True
):
r
=
"Article:
\n
"
+
doc
[
'article'
]
+
'
\n\n
'
def
doc_to_text
(
self
,
doc
):
# TODO: implement
pass
r
+=
doc
[
'problems'
]
>>
apply
(
enumerate
)
>>
each
(
lambda
x
:
'Q: '
+
x
[
1
][
'question'
]
+
'
\n\n
A:'
+
((
' '
+
x
[
1
][
'options'
][[
'A'
,
'B'
,
'C'
,
'D'
].
index
(
x
[
1
][
'answer'
])])
\
if
x
[
0
]
!=
len
(
doc
[
'problems'
])
-
1
or
include_target
else
''
))
\
>>
join
(
'
\n\n
'
)
return
r
def
doc_to_target
(
self
,
doc
):
# TODO: implement
pass
def
construct_requests
(
self
,
doc
,
ctx
):
""" Uses RequestFactory to construct Requests and returns an iterable of
...
...
lm_eval/tasks/sat.py
View file @
60a6fd8c
...
...
@@ -9,6 +9,8 @@ from ..utils import sh
class
SATAnalogies
(
Dataset
):
NEEDS_MANUAL_DL
=
True
def
__init__
(
self
):
super
().
__init__
()
...
...
lm_eval/tasks/storycloze.py
View file @
60a6fd8c
...
...
@@ -5,8 +5,8 @@ from ..utils import sh
import
csv
class
StoryCloze
(
Dataset
):
def
__init__
(
self
):
self
.
download
()
NEEDS_MANUAL_DL
=
True
def
download
(
self
):
#TODO: replace with Eye link
pass
...
...
@@ -30,7 +30,7 @@ class StoryCloze(Dataset):
def
validation_docs
(
self
):
return
self
.
load_doc
(
"data/storycloze/cloze_test_val__winter2018-cloze_test_ALL_val - 1 - 1.csv"
)
return
self
.
load_doc
(
"data/storycloze/cloze_test_val__winter2018-cloze_test_ALL_val - 1 - 1.csv"
)
def
test_docs
(
self
):
return
self
.
load_doc
(
"data/storycloze/cloze_test_test__winter2018-cloze_test_ALL_test - 1.csv"
)
...
...
lm_eval/tasks/superglue.py
View file @
60a6fd8c
...
...
@@ -75,6 +75,7 @@ class CommitmentBank(HFTask):
return
True
def
fewshot_description
(
self
):
# TODO: figure out actual description
return
"Given a premise and a hypothesis, classify whether the author of the premise is committed"
\
"to the truth of the hypothesis. The three possible labels are true, false or neither."
...
...
@@ -145,6 +146,7 @@ class Copa(HFTask):
return
True
def
fewshot_description
(
self
):
# TODO: figure out actual description
return
"Given a premise and one alternative with a causal relation to the premise and another without,"
\
"choose the more plausible alternative"
...
...
@@ -208,6 +210,7 @@ class MultiRC(HFTask):
return
True
def
fewshot_description
(
self
):
# TODO: figure out actual description
return
"READING COMPREHENSION ANSWER KEY"
def
doc_to_text
(
self
,
doc
):
...
...
@@ -260,24 +263,37 @@ class ReCoRD(HFTask):
def
has_test_docs
(
self
):
return
True
def
fewshot_description
(
self
):
# TODO: figure out actual description
return
""
def
training_docs
(
self
):
# In ReCoRD, each doc manifests multiple "examples" in the context of few shot example packing.
# Each doc consists of multiple answer candidates, each of which is scored yes/no.
# Hence, we create one "doc" for each (context + passage, answer) pair.
# Moreover, we only use the correct answers for context packing
# (This is not an issue for evaluation, where we can directly score multiple candidates at once).
if
self
.
has_training_docs
():
if
self
.
_training_docs
is
None
:
self
.
_training_docs
=
[]
for
doc
in
self
.
data
[
"train"
]:
for
entity
in
list
(
set
(
doc
[
"entities"
])):
self
.
_training_docs
.
append
({
"passage"
:
doc
[
"passage"
],
"query"
:
doc
[
"query"
],
"entity"
:
entity
,
"label"
:
entity
in
doc
[
"answers"
],
})
return
self
.
_training_docs
if
self
.
_training_docs
is
None
:
self
.
_training_docs
=
[]
for
doc
in
self
.
data
[
"train"
]:
for
entity
in
list
(
set
(
doc
[
"entities"
])):
self
.
_training_docs
.
append
({
"passage"
:
doc
[
"passage"
],
"query"
:
doc
[
"query"
],
"entity"
:
entity
,
"label"
:
entity
in
doc
[
"answers"
],
})
return
self
.
_training_docs
def
validation_docs
(
self
):
for
doc
in
self
.
data
[
"validation"
]:
for
entity
in
list
(
set
(
doc
[
"entities"
])):
yield
{
"passage"
:
doc
[
"passage"
],
"query"
:
doc
[
"query"
],
"entity"
:
entity
,
"label"
:
entity
in
doc
[
"answers"
],
}
def
doc_to_text
(
self
,
doc
):
initial_text
,
*
highlights
=
doc
[
"passage"
].
strip
().
split
(
"
\n
@highlight
\n
"
)
...
...
@@ -296,7 +312,7 @@ class ReCoRD(HFTask):
def
construct_requests
(
self
,
doc
,
ctx
):
requests
=
[
rf
.
loglikelihood
(
ctx
,
self
.
format_answer
(
query
=
doc
[
"query"
],
entity
=
entity
))
for
entity
in
doc
[
"entit
ies
"
]
for
entity
in
doc
[
"entit
y
"
]
]
return
requests
...
...
@@ -342,6 +358,10 @@ class WordsInContext(HFTask):
def
has_test_docs
(
self
):
return
True
def
fewshot_description
(
self
):
# TODO: figure out actual description
return
""
def
doc_to_text
(
self
,
doc
):
return
"{}
\n
{}
\n
Question: Is the word '{}' used in the same way in the"
\
" two sentences above?
\n
answer:"
.
format
(
...
...
@@ -405,6 +425,7 @@ class SGWinogradSchemaChallenge(HFTask):
return
self
.
_training_docs
def
fewshot_description
(
self
):
# TODO: figure out actual description
return
"Final Exam with Answer Key
\n
"
\
"Instructions: Please carefully read the following passages. "
\
"For each passage, you must identify which noun the pronoun marked in *bold*"
\
...
...
lm_eval/tasks/triviaqa.py
View file @
60a6fd8c
import
os
import
json
import
random
from
lm_eval.base
import
Dataset
from
..utils
import
sh
class
TriviaQA
(
Dataset
):
def
__init__
(
self
):
self
.
download
()
def
download
(
self
):
#pass
#TODO: don't download if files already there
sh
(
"""
mkdir -p data/triviaqa
wget http://nlp.cs.washington.edu/triviaqa/data/triviaqa-unfiltered.tar.gz -O data/triviaqa/trivia_qa-unfiltered.tar.gz
tar -xf data/triviaqa/trivia_qa-unfiltered.tar.gz
mv triviaqa-unfiltered/ data/triviaqa/
"""
)
if
not
os
.
path
.
exists
(
'data/triviaqa'
):
sh
(
"""
mkdir -p data/triviaqa
wget http://nlp.cs.washington.edu/triviaqa/data/triviaqa-unfiltered.tar.gz -O data/triviaqa/trivia_qa-unfiltered.tar.gz
tar -xf data/triviaqa/trivia_qa-unfiltered.tar.gz
mv triviaqa-unfiltered/ data/triviaqa/
"""
)
def
has_training_docs
(
self
):
return
True
...
...
lm_eval/tasks/wsc273.py
View file @
60a6fd8c
...
...
@@ -71,12 +71,14 @@ class WinogradSchemaChallenge273(Dataset):
docs
.
append
(
doc
)
return
docs
def
doc_to_text
(
self
,
doc
,
include_target
=
True
):
# WSC273 is currently only writing out full examples. Partial evaluation needs implementing.
text
=
doc
[
'completions'
][
'T'
]
+
' True. '
+
doc
[
'completions'
][
'F'
]
+
' False.'
return
text
def
doc_to_text
(
self
,
doc
):
# TODO: implement
pass
def
doc_to_target
(
self
,
doc
):
# TODO: implement
pass
def
construct_requests
(
self
,
doc
,
ctx
):
""" Uses RequestFactory to construct Requests and returns an iterable of
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment