Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
1a159d6b
Commit
1a159d6b
authored
Feb 09, 2021
by
Leo Gao
Browse files
Merge branch 'master' of github.com:EleutherAI/lm_evaluation_harness into winograd-fixes
parents
8b038c2a
7614a8f3
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
148 additions
and
42 deletions
+148
-42
lm_eval/models/gpt2.py
lm_eval/models/gpt2.py
+29
-28
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+6
-0
lm_eval/tasks/arithmetic.py
lm_eval/tasks/arithmetic.py
+4
-1
lm_eval/tasks/piqa.py
lm_eval/tasks/piqa.py
+2
-2
lm_eval/tasks/qa4mre.py
lm_eval/tasks/qa4mre.py
+92
-0
lm_eval/tasks/race.py
lm_eval/tasks/race.py
+6
-3
lm_eval/tasks/sciq.py
lm_eval/tasks/sciq.py
+6
-5
lm_eval/tasks/superglue.py
lm_eval/tasks/superglue.py
+1
-1
main.py
main.py
+2
-2
No files found.
lm_eval/models/gpt2.py
View file @
1a159d6b
...
@@ -7,44 +7,45 @@ from tqdm import tqdm
...
@@ -7,44 +7,45 @@ from tqdm import tqdm
class
GPT2LM
(
LM
):
class
GPT2LM
(
LM
):
def
__init__
(
self
,
device
=
"cpu"
):
def
__init__
(
self
,
device
=
"cpu"
,
pretrained
=
'gpt2'
):
self
.
device
=
torch
.
device
(
device
)
self
.
device
=
torch
.
device
(
device
)
self
.
gpt2
=
transformers
.
GPT2LMHeadModel
.
from_pretrained
(
'gpt2'
).
to
(
self
.
device
)
self
.
gpt2
=
transformers
.
GPT2LMHeadModel
.
from_pretrained
(
pretrained
).
to
(
self
.
device
)
self
.
gpt2
.
eval
()
self
.
gpt2
.
eval
()
self
.
tokenizer
=
transformers
.
GPT2TokenizerFast
.
from_pretrained
(
'gpt2'
)
self
.
tokenizer
=
transformers
.
GPT2TokenizerFast
.
from_pretrained
(
pretrained
)
self
.
tokenizer
.
pad_token
=
"<|endoftext|>"
self
.
tokenizer
.
pad_token
=
"<|endoftext|>"
@classmethod
def create_from_arg_string(cls, arg_string):
    """Alternate constructor from a comma-separated "key=value" string.

    Recognized keys: "device" (default "cpu") and "pretrained"
    (default "gpt2"); unknown keys are ignored by this method.
    """
    args = utils.simple_parse_args_string(arg_string)
    return cls(device=args.get("device", "cpu"), pretrained=args.get("pretrained", "gpt2"))
def loglikelihood(self, requests):
    """Score (context, continuation) pairs.

    :param requests: iterable of (context, continuation) string pairs.
    :returns: list of (log-likelihood sum of the continuation tokens,
        whether the continuation is the model's greedy decode) tuples.
    """
    res = []
    # TODO: vectorize properly
    # NOTE(review): reconstructed from a diff view; the no_grad wrapper is
    # ambiguous there — confirm it is present in the committed file.
    with torch.no_grad():
        for context, continuation in tqdm(requests):
            # when too long to fit in context, truncate from the left
            if context == "":
                # end of text as context (50256 is GPT-2's <|endoftext|> id)
                context_enc = [50256]
            else:
                context_enc = self.tokenizer.encode(context)
            continuation_enc = self.tokenizer.encode(continuation)
            # Keep only the last 1024 tokens (GPT-2's maximum context window).
            inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
            # Effective context length after any left-truncation above.
            ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
            cont_toks = inp[:, ctxlen:]  # [batch, seq]
            # Logits at position i predict token i+1, hence the ctxlen-1:-1 slice.
            logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
            greedy_tokens = logits.argmax(dim=-1)
            max_equal = (greedy_tokens == cont_toks).all()
            # Gather the log-prob of each actual continuation token.
            logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]
            res.append((float(logits.sum()), bool(max_equal)))
    return res
...
...
lm_eval/tasks/__init__.py
View file @
1a159d6b
...
@@ -20,6 +20,7 @@ from . import triviaqa
...
@@ -20,6 +20,7 @@ from . import triviaqa
from
.
import
pubmedqa
from
.
import
pubmedqa
from
.
import
sciq
from
.
import
sciq
from
.
import
webqs
from
.
import
webqs
from
.
import
qa4mre
TASK_REGISTRY
=
{
TASK_REGISTRY
=
{
...
@@ -48,8 +49,13 @@ TASK_REGISTRY = {
...
@@ -48,8 +49,13 @@ TASK_REGISTRY = {
"lambada"
:
lambada
.
LAMBADA
,
"lambada"
:
lambada
.
LAMBADA
,
"piqa"
:
piqa
.
PiQA
,
"piqa"
:
piqa
.
PiQA
,
# Science related
"pubmedqa"
:
pubmedqa
.
Pubmed_QA
,
"pubmedqa"
:
pubmedqa
.
Pubmed_QA
,
"sciq"
:
sciq
.
SciQ
,
"sciq"
:
sciq
.
SciQ
,
#"qa4mre" : qa4mre.QA4MRE,
"qa4mre_2011"
:
qa4mre
.
QA4MRE_2011
,
"qa4mre_2012"
:
qa4mre
.
QA4MRE_2012
,
"qa4mre_2013"
:
qa4mre
.
QA4MRE_2013
,
#"triviaqa": triviaqa.TriviaQA,
#"triviaqa": triviaqa.TriviaQA,
"arc_easy"
:
arc
.
ARCEasy
,
"arc_easy"
:
arc
.
ARCEasy
,
...
...
lm_eval/tasks/arithmetic.py
View file @
1a159d6b
...
@@ -56,7 +56,10 @@ class Arithmetic(Task):
...
@@ -56,7 +56,10 @@ class Arithmetic(Task):
return
doc
.
completion
return
doc
.
completion
def load_doc(self, doc_json):
    """Build an ArithmeticDoc from one JSON record.

    Normalizes the context: collapses blank lines and rewrites the
    "Q:"/"A:" markers to "Question:"/"Answer:". The completion is kept
    verbatim (not stripped).
    """
    return ArithmeticDoc(context=doc_json['context'].strip()
        .replace('\n\n', '\n')
        .replace('Q:', 'Question:')
        .replace('A:', 'Answer:'), completion=doc_json['completion'])
def
construct_requests
(
self
,
doc
,
ctx
):
def
construct_requests
(
self
,
doc
,
ctx
):
ll
,
is_prediction
=
rf
.
loglikelihood
(
ctx
,
doc
.
completion
)
ll
,
is_prediction
=
rf
.
loglikelihood
(
ctx
,
doc
.
completion
)
...
...
lm_eval/tasks/piqa.py
View file @
1a159d6b
...
@@ -28,8 +28,8 @@ class PiQA(HFTask):
...
@@ -28,8 +28,8 @@ class PiQA(HFTask):
return
" "
+
solutions
[
doc
[
"label"
]]
return
" "
+
solutions
[
doc
[
"label"
]]
def construct_requests(self, doc, ctx):
    """Request log-likelihoods for both PIQA solutions.

    A leading space is prepended to each solution so it tokenizes as a
    continuation of the prompt rather than fused with its last word.
    Returns the (ll_sol1, ll_sol2) request objects.
    """
    ll_1, _ = rf.loglikelihood(ctx, " " + doc['sol1'])
    ll_2, _ = rf.loglikelihood(ctx, " " + doc['sol2'])
    return ll_1, ll_2
def
process_results
(
self
,
doc
,
results
):
def
process_results
(
self
,
doc
,
results
):
...
...
lm_eval/tasks/qa4mre.py
0 → 100644
View file @
1a159d6b
import
os
import
numpy
as
np
from
best_download
import
download_file
from
lm_eval.base
import
MultipleChoiceTask
,
rf
,
mean
import
xml.etree.ElementTree
as
ET
import
random
class QA4MRE(MultipleChoiceTask):
    """CLEF QA4MRE machine-reading multiple-choice task (English gold standard).

    Concrete per-edition tasks subclass this and set ``YEAR``.
    """

    # Campaign edition; pinned by the QA4MRE_20xx subclasses.
    YEAR = None

    def download(self):
        """Download the gold-standard XML for ``self.YEAR`` into data/qa4mre/."""
        year = self.YEAR
        lang = "EN"
        base_path = (
            "http://nlp.uned.es/clef-qa/repository/js/scripts/downloadFile.php?"
            "file=/var/www/html/nlp/clef-qa/repository/resources/QA4MRE/"
        )
        # TODO: add side tasks?
        # Each edition keeps its gold standard under a different subtree.
        variable_year_path = {
            2011: '2011/Training_Data/Goldstandard/',
            2012: '2012/Main_Task/Training_Data/Goldstandard/Used_in_Evaluation/',
            2013: '2013/Main_Task/Training_Data/Goldstandard/',
        }
        sha256sums = {
            2011: "6d2524952a3a015f2a82df785b85b5578681e3602ec276b4e72c01f4ebc50034",
            2012: "f9edaf408f8ac93f89a643a0d0b19263a1bb5ce64f19b2af10df279a656dfb24",
            2013: "c60e5aa4ec77e0493ef0b11d46bd1d74d58a499a3a2f871b8cf3af9536f0f094",
        }
        vpath = variable_year_path[year]
        url_path = f"{base_path}{vpath}QA4MRE-{year}-{lang}_GS.xml"
        local_path = f"data/qa4mre/QA4MRE-{year}-{lang}_GS.xml"
        # makedirs also creates the parent "data/" directory; plain os.mkdir
        # raised FileNotFoundError when "data" did not exist yet.
        os.makedirs("data/qa4mre", exist_ok=True)
        # Bug fix: the existence check previously tested
        # f"data/qa4mre/QA4MRE-{year}-{lang}" (missing the "_GS.xml" suffix),
        # which never matched the file actually written, so the dataset was
        # re-downloaded on every run.
        if not os.path.isfile(local_path):
            download_file(url_path, local_path, checksum=sha256sums[year])

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def fewshot_examples(self, k):
        # Since only test docs sample from test docs
        # NOTE(review): assumes the base class initializes _training_docs to
        # None — confirm against MultipleChoiceTask.
        if self._training_docs is None:
            self._training_docs = list(self.test_docs())
        return random.sample(self._training_docs, k)

    def _convert_standard(self, question):
        """Convert one <q> XML element into the standard doc dict."""
        choices = [i.text for i in question.iter('answer')]
        out_doc = {
            "query": question.find('q_str').text,
            "choices": choices,
            # a_id is 1-based in the XML; "gold" is a 0-based index into choices.
            "gold": int(question.find("./answer[@correct='Yes']").attrib["a_id"]) - 1,
        }
        return out_doc

    def load_docs(self, textfilename, tfds=False):
        """Yield doc dicts (query/choices/gold/source) parsed from the XML file."""
        tree = ET.parse(textfilename)
        root = tree.getroot()
        # TODO: context is much larger than the context sometimes
        # at the moment, it just gets left-truncated by LM automatically, and maybe that's good enough?
        for reading_test in root.iter('reading-test'):
            src = reading_test[0].text
            # NOTE(review): "\'" equals "'" in Python source, so this replace
            # is a no-op as written — possibly meant to strip backslash-escaped
            # quotes ("\\'"); confirm intent before changing.
            src = src.strip().replace("\'", "'")
            for question in reading_test.iter('q'):
                out_doc = self._convert_standard(question)
                out_doc['source'] = src
                yield out_doc

    def fewshot_description(self):
        return ""

    def test_docs(self):
        return self.load_docs(f"data/qa4mre/QA4MRE-{self.YEAR}-EN_GS.xml")

    def doc_to_text(self, doc):
        return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"])
# Per-edition QA4MRE tasks: each subclass only pins the campaign year
# consumed by QA4MRE.download() / QA4MRE.test_docs().
class QA4MRE_2011(QA4MRE):
    YEAR = 2011


class QA4MRE_2012(QA4MRE):
    YEAR = 2012


class QA4MRE_2013(QA4MRE):
    YEAR = 2013
lm_eval/tasks/race.py
View file @
1a159d6b
...
@@ -82,9 +82,12 @@ class RACE(HFTask):
...
@@ -82,9 +82,12 @@ class RACE(HFTask):
def doc_to_text(self, doc):
    """Render a RACE article plus all-but-last problems as the prompt.

    Cloze-style questions (ending in " _ .") are answered inline; the
    others get explicit "Question:"/"Answer:" lines. The final problem's
    question is appended without its answer, for the model to complete.
    """
    text = 'Article: ' + doc['article'] + '\n\n'
    for problem in doc['problems'][:-1]:
        if problem['question'][-6:] == ' _ .':
            # Cloze: drop the " _ ." placeholder tail and splice the answer in.
            text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
        else:
            question = 'Question: ' + problem['question'] + '\n'
            answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
            text += question + answer
    text += self.last_problem(doc)['question']
    return text
...
...
lm_eval/tasks/sciq.py
View file @
1a159d6b
...
@@ -3,6 +3,7 @@ import json
...
@@ -3,6 +3,7 @@ import json
from
..utils
import
sh
from
..utils
import
sh
from
lm_eval.base
import
MultipleChoiceTask
,
rf
,
mean
from
lm_eval.base
import
MultipleChoiceTask
,
rf
,
mean
import
zipfile
import
zipfile
from
best_download
import
download_file
class
SciQ
(
MultipleChoiceTask
):
class
SciQ
(
MultipleChoiceTask
):
...
@@ -10,9 +11,11 @@ class SciQ(MultipleChoiceTask):
...
@@ -10,9 +11,11 @@ class SciQ(MultipleChoiceTask):
def download(self):
    """Fetch SciQ.zip (with checksum verification) and unpack it to data/sciq/."""
    if not os.path.exists('data/sciq'):
        # NOTE(review): os.mkdir fails if the parent "data/" dir is absent;
        # presumably created earlier elsewhere — confirm.
        os.mkdir('data/sciq')
    download_file(
        'https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip',
        'data/sciq/SciQ.zip',
        '7f3312f6ac6b09970b32942d106a8c44ec0dad46a0369f17d635aff8e348a87c',
    )
    with zipfile.ZipFile("data/sciq/SciQ.zip", "r") as zf:
        zf.extractall("data/sciq/")
...
@@ -48,8 +51,6 @@ class SciQ(MultipleChoiceTask):
...
@@ -48,8 +51,6 @@ class SciQ(MultipleChoiceTask):
yield
self
.
_convert_standard
(
record
)
yield
self
.
_convert_standard
(
record
)
def fewshot_description(self):
    """Return an empty description to keep prompts inside the context window."""
    # Average ctx length in labelled dataset is 238.9
    # 2 few-shot examples pushes it beyond context window
    return ""
def
training_docs
(
self
):
def
training_docs
(
self
):
...
...
lm_eval/tasks/superglue.py
View file @
1a159d6b
...
@@ -218,7 +218,7 @@ class MultiRC(HFTask):
...
@@ -218,7 +218,7 @@ class MultiRC(HFTask):
return
f
"
{
doc
[
'paragraph'
]
}
\n
Question:
{
doc
[
'question'
]
}
\n
Answer:"
return
f
"
{
doc
[
'paragraph'
]
}
\n
Question:
{
doc
[
'question'
]
}
\n
Answer:"
def doc_to_target(self, doc):
    """Return the gold continuation for a MultiRC doc.

    The leading space makes the answer tokenize as a continuation of the
    "Answer:" prompt instead of merging with its final token.
    """
    return " " + self.format_answer(answer=doc["answer"], label=doc["label"])
@
staticmethod
@
staticmethod
def
format_answer
(
answer
,
label
):
def
format_answer
(
answer
,
label
):
...
...
main.py
View file @
1a159d6b
...
@@ -20,7 +20,7 @@ def parse_args():
...
@@ -20,7 +20,7 @@ def parse_args():
parser
.
add_argument
(
'--seed'
,
type
=
int
,
default
=
1234
)
parser
.
add_argument
(
'--seed'
,
type
=
int
,
default
=
1234
)
parser
.
add_argument
(
'--output_path'
,
default
=
None
)
parser
.
add_argument
(
'--output_path'
,
default
=
None
)
parser
.
add_argument
(
'--limit'
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
'--limit'
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
'--cache'
,
action
=
"store_true"
)
parser
.
add_argument
(
'--
no_
cache'
,
action
=
"store_true"
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
def
main
():
def
main
():
...
@@ -31,7 +31,7 @@ def main():
...
@@ -31,7 +31,7 @@ def main():
lm
=
models
.
get_model
(
args
.
model
).
create_from_arg_string
(
args
.
model_args
)
lm
=
models
.
get_model
(
args
.
model
).
create_from_arg_string
(
args
.
model_args
)
if
args
.
cache
:
if
not
args
.
no_
cache
:
lm
=
base
.
CachingLM
(
lm
,
'lm_cache/'
+
args
.
model
+
'_'
+
args
.
model_args
.
replace
(
'='
,
'-'
).
replace
(
','
,
'_'
)
+
'.db'
)
lm
=
base
.
CachingLM
(
lm
,
'lm_cache/'
+
args
.
model
+
'_'
+
args
.
model_args
.
replace
(
'='
,
'-'
).
replace
(
','
,
'_'
)
+
'.db'
)
if
args
.
tasks
==
"all_tasks"
:
if
args
.
tasks
==
"all_tasks"
:
task_names
=
tasks
.
ALL_TASKS
task_names
=
tasks
.
ALL_TASKS
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment