gaoqiong / lm-evaluation-harness

Commit 1f8a8c1d, authored Jun 11, 2022 by jon-tow

Merge branch 'master' of https://github.com/EleutherAI/lm-evaluation-harness into remove-dataset

Parents: b4c0275d, b0acb337

Changes: 255. Showing 20 changed files with 1018 additions and 750 deletions (+1018, -750).
lm_eval/tasks/quac.py                                   +123  -106
lm_eval/tasks/race.py                                    +53   -42
lm_eval/tasks/sat.py                                     +13    -5
lm_eval/tasks/sciq.py                                     +8    -2
lm_eval/tasks/squad.py                                  +219  -163
lm_eval/tasks/storycloze.py                              +27   -21
lm_eval/tasks/superglue.py                               +72   -98
lm_eval/tasks/translation.py                             +28    -9
lm_eval/tasks/triviaqa.py                                +12   -10
lm_eval/tasks/truthfulqa.py                              +63   -69
lm_eval/tasks/unscramble.py                               +9    -9
lm_eval/tasks/webqs.py                                   +13   -11
lm_eval/tasks/wikitext.py                                 +5    -2
lm_eval/tasks/winogrande.py                             +132  -132
lm_eval/tasks/wsc273.py                                  +24   -13
lm_eval/utils.py                                         +47   -32
main.py                                                  +54   -19
pile_statistics.json                                     +37    -0
scripts/clean_training_data/README.md                     +6    -7
scripts/clean_training_data/compress_and_package.py      +73    -0
lm_eval/tasks/quac.py (view file @ 1f8a8c1d)

"""
QuAC: Question Answering in Context
https://arxiv.org/abs/1808.07036

Question Answering in Context (QuAC) is a dataset for modeling, understanding, and
participating in information seeking dialog. Data instances consist of an interactive
dialog between two crowd workers: (1) a student who poses a sequence of freeform
questions to learn as much as possible about a hidden Wikipedia text, and (2)
a teacher who answers the questions by providing short excerpts (spans) from the text.

Homepage: https://quac.ai/
"""
import inspect
import lm_eval.datasets.quac.quac
from lm_eval.base import Task


_CITATION = """
@article{choi2018quac,
    title={Quac: Question answering in context},
    author={Choi, Eunsol and He, He and Iyyer, Mohit and Yatskar, Mark and Yih, Wen-tau and Choi, Yejin and Liang, Percy and Zettlemoyer, Luke},
    journal={arXiv preprint arXiv:1808.07036},
    year={2018}
}
"""


class QuAC(Task):
    VERSION = 0
    DATASET_PATH = inspect.getfile(lm_eval.datasets.quac.quac)
    DATASET_NAME = None

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self._training_docs is None:
            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
        return self._training_docs

    def validation_docs(self):
        return map(self._process_doc, self.dataset["validation"])

    def test_docs(self):
        raise NotImplementedError("QuAC has no test docs.")

    def _process_doc(self, doc):
        doc["title"] = doc["title"] + " - " + doc["section_title"]
        return doc

    def doc_to_text(self, doc):
        return (
            "TITLE: "
            + doc["title"]
            + "\n"
            + "PARAGRAPH: "
            + doc["paragraph"]
            + "\n\n"
            + "Q: "
            + doc["question"]
            + "\n\n"
            + "A: "
        )

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["paragraph"]

    def doc_to_target(self, doc):
        return doc["answer"]

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        # TODO: implement evaluation.
        raise NotImplementedError("Evaluation not implemented")

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        # TODO: implement evaluation.
        raise NotImplementedError("Evaluation not implemented")

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        # TODO: implement evaluation.
        raise NotImplementedError("Evaluation not implemented")

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        # TODO: implement evaluation.
        raise NotImplementedError("Evaluation not implemented")
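For reference, a minimal standalone sketch of the prompt QuAC.doc_to_text builds; the record below is invented, not a real QuAC example:

# Invented toy record; _process_doc has already merged section_title into title.
doc = {
    "title": "Alps - Geography",
    "paragraph": "The Alps are the highest mountain range lying entirely in Europe.",
    "question": "Where are the Alps located?",
}
prompt = (
    "TITLE: " + doc["title"] + "\n"
    + "PARAGRAPH: " + doc["paragraph"] + "\n\n"
    + "Q: " + doc["question"] + "\n\n"
    + "A: "
)
print(prompt)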
lm_eval/tasks/race.py (view file @ 1f8a8c1d)

...
@@ -20,7 +20,7 @@ _CITATION = """
@article{lai2017large,
    title={RACE: Large-scale ReAding Comprehension Dataset From Examinations},
    author={Lai, Guokun and Xie, Qizhe and Liu, Hanxiao and Yang, Yiming and Hovy, Eduard},
    journal={arXiv preprint arXiv:1704.04683},
    year={2017}
}
"""
...
@@ -40,7 +40,7 @@ class RACE(Task):
    DATASET_NAME = "high"

    cache = {}
    letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3}

    def has_training_docs(self):
        return True
...
@@ -59,17 +59,27 @@ class RACE(Task):
        # is shown that one document is made per passage.

        r = collections.defaultdict(list)
        for item in datasets.load_dataset(
            path=self.DATASET_PATH, name=self.DATASET_NAME
        )[set]:
            r[item["article"]].append(item)
        res = list(
            r.values()
            >> each(
                lambda x: {
                    "article": x[0]["article"],
                    "problems": x
                    >> each(
                        lambda y: {
                            "question": y["question"],
                            "answer": y["answer"],
                            "options": y["options"],
                        }
                    ),
                }
            )
        )
        self.cache[set] = res
        return res
...
@@ -85,49 +95,56 @@ class RACE(Task):
    @classmethod
    def get_answer_option(cls, problem):
        answer = cls.letter_to_num[problem["answer"]]
        return problem["options"][answer]

    @classmethod
    def last_problem(cls, doc):
        return doc["problems"][-1]

    def doc_to_text(self, doc):
        text = "Article: " + doc["article"] + "\n\n"
        for problem in doc["problems"][:-1]:
            if problem["question"][-6:] == " _ .":
                text += (
                    problem["question"][-5:] + self.get_answer_option(problem) + "\n"
                )
            else:
                question = "Question: " + problem["question"] + "\n"
                answer = "Answer: " + self.get_answer_option(problem) + "\n"
                text += question + answer
        text += self.last_problem(doc)["question"]
        return text

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["article"]

    def doc_to_target(self, doc):
        return " " + self.get_answer_option(self.last_problem(doc))

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        problem = self.last_problem(doc)
        ll_choices = [
            rf.loglikelihood(ctx, " " + problem["options"][i])[0] for i in range(4)
        ]
        return ll_choices

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
...
@@ -135,28 +152,22 @@ class RACE(Task):
        :param results:
            The results of the requests created in construct_requests.
        """
        gold = self.letter_to_num[self.last_problem(doc)["answer"]]
        pred = np.argmax(results)
        return {"acc": int(pred == gold)}

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {"acc": mean}

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {"acc": True}
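The passage-grouping code above relies on an `each` pipe helper defined elsewhere in race.py (not shown in this hunk). A self-contained sketch of the idiom, assuming an `__rrshift__`-based definition:

# Assumed reconstruction of the `each` helper: `iterable >> each(f)` maps f
# over the iterable, which is how the nested pipes above group problems per article.
class each:
    def __init__(self, f):
        self.f = f

    def __rrshift__(self, other):
        return list(map(self.f, other))


print([1, 2, 3] >> each(lambda x: x * 10))  # prints [10, 20, 30]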
lm_eval/tasks/sat.py (view file @ 1f8a8c1d)

...
@@ -59,11 +59,19 @@ class SATAnalogies(MultipleChoiceTask):
    def _process_doc(self, doc):
        return {
            "source": doc["source"],
            "query": doc["stem"].split(" ")[:2],
            "choices": [
                "{} is to {}".format(*c.split(" ")[:2]) for c in doc["choices"]
            ],
            "gold": ["a", "b", "c", "d", "e"].index(doc["solution"].strip()),
        }

    def doc_to_text(self, doc):
        return "{} is to {} as".format(*doc["query"])

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["source"] + "\n" + " ".join(doc["query"])
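A standalone sketch of the transformation _process_doc performs above; the analogy record below is invented for illustration:

# Invented record (fields mirror the dataset's stem/choices/solution layout).
doc = {
    "source": "example source",
    "stem": "lull trust verb",
    "choices": ["balk fortitude", "betray loyalty", "cajole compliance"],
    "solution": "c",
}
query = doc["stem"].split(" ")[:2]
choices = ["{} is to {}".format(*c.split(" ")[:2]) for c in doc["choices"]]
gold = ["a", "b", "c", "d", "e"].index(doc["solution"].strip())
print("{} is to {} as".format(*query))  # lull is to trust as
print(choices[gold])                    # cajole is to compliance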
lm_eval/tasks/sciq.py (view file @ 1f8a8c1d)

...
@@ -54,10 +54,10 @@ class SciQ(MultipleChoiceTask):
            doc["distractor3"],
            doc["correct_answer"],
        ]
        src = doc["support"]
        out_doc = {
            "source": src,
            "query": doc["question"],
            "choices": choices,
            "gold": 3,
        }
...
@@ -65,3 +65,9 @@ class SciQ(MultipleChoiceTask):
    def doc_to_text(self, doc):
        return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"]).strip()

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["source"] + " " + doc["query"]
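A quick sketch of the prompt this doc_to_text produces, run on an invented record:

# Invented SciQ-style record; the support passage precedes the question.
out_doc = {
    "source": "Mitochondria are the powerhouse of the cell.",
    "query": "Which organelle produces most of a cell's ATP?",
}
print("{}\nQuestion: {}\nAnswer:".format(out_doc["source"], out_doc["query"]).strip())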
lm_eval/tasks/squad.py (view file @ 1f8a8c1d)

"""
Know What You Don't Know: Unanswerable Questions for SQuAD
https://arxiv.org/pdf/1806.03822.pdf

Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset,
consisting of questions posed by crowdworkers on a set of Wikipedia articles,
where the answer to every question is a segment of text, or span, from the
corresponding reading passage, or the question might be unanswerable.
SQuAD2.0 combines the 100,000 questions in SQuAD1.1 with over 50,000 unanswerable
questions written adversarially by crowdworkers to look similar to answerable ones.
To do well on SQuAD2.0, systems must not only answer questions when possible, but
also determine when no answer is supported by the paragraph and abstain from answering.

Homepage: https://rajpurkar.github.io/SQuAD-explorer/
"""
import datasets
from math import exp
from lm_eval.base import rf, Task
from functools import partial
from packaging import version


_CITATION = """
@misc{rajpurkar2018know,
    title={Know What You Don't Know: Unanswerable Questions for SQuAD},
    author={Pranav Rajpurkar and Robin Jia and Percy Liang},
    year={2018},
    eprint={1806.03822},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""


def _squad_metric(predictions, references):
    squad_metric = datasets.load_metric("squad_v2")
    return squad_metric.compute(predictions=predictions, references=references)


def _squad_agg(key, items):
    predictions, references = zip(*items)
    return _squad_metric(predictions=predictions, references=references)[key]


class SQuAD2(Task):
    VERSION = 1
    DATASET_PATH = "squad_v2"
    DATASET_NAME = None

    # HF changed squad on us so we have to make sure we aren't running the old one
    assert version.parse(datasets.__version__) >= version.parse(
        "1.11.0"
    ), "datasets v1.11.0 or later required for SQuAD"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        return (
            "Title: "
            + doc["title"]
            + "\n\n"
            + "Background: "
            + doc["context"]
            + "\n\n"
            + "Question: "
            + doc["question"]
            + "\n\n"
            + "Answer:"
        )

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["context"]

    def doc_to_target(self, doc):
        answer_list = doc["answers"]["text"]
        if len(answer_list) > 0:
            answer = answer_list[0]
        else:
            answer = "unanswerable"
        return " " + answer

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        continuation = rf.greedy_until(ctx, ["\n"])
        is_unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
        return continuation, is_unanswerable

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        continuation, (logprob_unanswerable, _) = results

        no_answer_probability = exp(logprob_unanswerable)

        predictions = {
            "id": doc["id"],
            "prediction_text": continuation,
            "no_answer_probability": no_answer_probability,
        }

        references = {
            "id": doc["id"],
            "answers": doc["answers"],
        }

        return {
            "exact": (
                predictions,
                references,
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "f1": (
                predictions,
                references,
            ),  # The F-score of predicted tokens versus the gold answer
            "HasAns_exact": (
                predictions,
                references,
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "HasAns_f1": (
                predictions,
                references,
            ),  # The F-score of predicted tokens versus the gold answer
            "NoAns_exact": (
                predictions,
                references,
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "NoAns_f1": (
                predictions,
                references,
            ),  # The F-score of predicted tokens versus the gold answer
            "best_exact": (
                predictions,
                references,
            ),  # Best exact match (with varying threshold)
            "best_f1": (predictions, references),  # Best F1 (with varying threshold)
        }

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            "exact": partial(_squad_agg, "exact"),  # Exact match (the normalized answer exactly match the gold answer)
            "f1": partial(_squad_agg, "f1"),  # The F-score of predicted tokens versus the gold answer
            "HasAns_exact": partial(_squad_agg, "HasAns_exact"),  # Exact match (the normalized answer exactly match the gold answer)
            "HasAns_f1": partial(_squad_agg, "HasAns_f1"),  # The F-score of predicted tokens versus the gold answer
            "NoAns_exact": partial(_squad_agg, "NoAns_exact"),  # Exact match (the normalized answer exactly match the gold answer)
            "NoAns_f1": partial(_squad_agg, "NoAns_f1"),  # The F-score of predicted tokens versus the gold answer
            "best_exact": partial(_squad_agg, "best_exact"),  # Best exact match (with varying threshold)
            "best_f1": partial(_squad_agg, "best_f1"),  # Best F1 (with varying threshold)
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "exact": True,  # Exact match (the normalized answer exactly match the gold answer)
            "f1": True,  # The F-score of predicted tokens versus the gold answer
            "HasAns_exact": True,  # Exact match (the normalized answer exactly match the gold answer)
            "HasAns_f1": True,  # The F-score of predicted tokens versus the gold answer
            "NoAns_exact": True,  # Exact match (the normalized answer exactly match the gold answer)
            "NoAns_f1": True,  # The F-score of predicted tokens versus the gold answer
            "best_exact": True,  # Best exact match (with varying threshold)
            "best_f1": True,  # Best F1 (with varying threshold)
        }
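A minimal sketch of the squad_v2 metric call that _squad_metric wraps; it needs the datasets package (a version where load_metric still exists, pre-2.x), and the id/answer values are invented:

import datasets

squad_metric = datasets.load_metric("squad_v2")
predictions = [
    {"id": "q1", "prediction_text": "Normandy", "no_answer_probability": 0.0}
]
references = [
    {"id": "q1", "answers": {"text": ["Normandy"], "answer_start": [159]}}
]
result = squad_metric.compute(predictions=predictions, references=references)
print(result["exact"], result["f1"])  # 100.0 100.0 for this invented pair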
lm_eval/tasks/storycloze.py (view file @ 1f8a8c1d)

...
@@ -65,12 +65,27 @@ class StoryCloze(Task):
        return self.dataset["test"]

    def doc_to_text(self, doc):
        return " ".join(
            [
                doc["input_sentence_1"],
                doc["input_sentence_2"],
                doc["input_sentence_3"],
                doc["input_sentence_4"],
            ]
        )

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return " ".join(
            [
                doc["input_sentence_1"],
                doc["input_sentence_2"],
                doc["input_sentence_3"],
                doc["input_sentence_4"],
            ]
        )

    def doc_to_target(self, doc):
        clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]]
...
@@ -78,7 +93,7 @@ class StoryCloze(Task):
        return " " + clozes[doc["answer_right_ending"] - 1]

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
...
@@ -89,10 +104,7 @@ class StoryCloze(Task):
            part of the document for `doc`.
        """
        clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]]
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in clozes
        ]
        return lls

    def process_results(self, doc, results):
...
@@ -106,10 +118,8 @@ class StoryCloze(Task):
            The results of the requests created in construct_requests.
        """
        gold = doc["answer_right_ending"] - 1
        acc = 1.0 if np.argmax(results) == gold else 0.0
        return {"acc": acc}

    def aggregation(self):
        """
...
@@ -117,9 +127,7 @@ class StoryCloze(Task):
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {"acc": mean}

    def higher_is_better(self):
        """
...
@@ -127,9 +135,7 @@ class StoryCloze(Task):
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {"acc": True}


class StoryCloze2016(StoryCloze):
...
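A worked sketch of the scoring rule in process_results above, with invented log-likelihoods for the two candidate endings:

import numpy as np

# answer_right_ending is 1-indexed in the dataset, hence the minus 1.
results = [-12.3, -9.8]  # invented log-likelihoods for endings 1 and 2
gold = 2 - 1
acc = 1.0 if np.argmax(results) == gold else 0.0
print(acc)  # 1.0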
lm_eval/tasks/superglue.py (view file @ 1f8a8c1d)

...
@@ -56,14 +56,20 @@ class BoolQ(Task):
    def doc_to_text(self, doc):
        return f"{doc['passage']}\nQuestion: {doc['question']}?\nAnswer:"

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["passage"]

    def doc_to_target(self, doc):
        return " " + yesno(doc["label"])

    def construct_requests(self, doc, ctx):
        ll_yes, _ = rf.loglikelihood(ctx, " yes")
        ll_no, _ = rf.loglikelihood(ctx, " no")

        return ll_yes, ll_no
...
@@ -71,21 +77,15 @@ class BoolQ(Task):
        ll_yes, ll_no = results
        gold = doc["label"]

        acc = 1.0 if (ll_yes > ll_no) == gold else 0.0

        return {"acc": acc}

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}


class CommitmentBank(Task):
...
@@ -123,27 +123,21 @@ class CommitmentBank(Task):
        return " {}".format({0: "True", 1: "False", 2: "Neither"}[doc["label"]])

    def construct_requests(self, doc, ctx):
        ll_true, _ = rf.loglikelihood(ctx, " True")
        ll_false, _ = rf.loglikelihood(ctx, " False")
        ll_neither, _ = rf.loglikelihood(ctx, " Neither")

        return ll_true, ll_false, ll_neither

    def process_results(self, doc, results):
        gold = doc["label"]
        pred = np.argmax(results)
        acc = 1.0 if pred == gold else 0.0

        return {"acc": acc, "f1": (pred, gold)}

    def higher_is_better(self):
        return {"acc": True, "f1": True}

    @classmethod
    def cb_multi_fi(cls, items):
...
@@ -155,7 +149,7 @@ class CommitmentBank(Task):
        f13 = sklearn.metrics.f1_score(y_true=golds == 2, y_pred=preds == 2)
        avg_f1 = mean([f11, f12, f13])
        return avg_f1

    def aggregation(self):
        return {
            "acc": mean,
...
@@ -201,7 +195,7 @@ class Copa(Task):
    def construct_requests(self, doc, ctx):
        choice1 = " " + self.convert_choice(doc["choice1"])
        choice2 = " " + self.convert_choice(doc["choice2"])

        ll_choice1, _ = rf.loglikelihood(ctx, choice1)
        ll_choice2, _ = rf.loglikelihood(ctx, choice2)
...
@@ -210,21 +204,15 @@ class Copa(Task):
    def process_results(self, doc, results):
        gold = doc["label"]
        pred = np.argmax(results)
        acc = 1.0 if pred == gold else 0.0

        return {"acc": acc}

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}

    @staticmethod
    def convert_choice(choice):
...
@@ -267,28 +255,22 @@ class MultiRC(Task):
    def construct_requests(self, doc, ctx):
        true_choice = self.format_answer(answer=doc["answer"], label=True)
        false_choice = self.format_answer(answer=doc["answer"], label=False)

        ll_true_choice, _ = rf.loglikelihood(ctx, f" {true_choice}")
        ll_false_choice, _ = rf.loglikelihood(ctx, f" {false_choice}")

        return ll_true_choice, ll_false_choice

    def process_results(self, doc, results):
        ll_true_choice, ll_false_choice = results
        pred = ll_true_choice > ll_false_choice
        return {"acc": (pred, doc)}

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": acc_all}


class ReCoRD(Task):
...
@@ -337,7 +319,7 @@ class ReCoRD(Task):
    @classmethod
    def format_answer(cls, query, entity):
        return f"  - {query}".replace("@placeholder", entity)

    def doc_to_target(self, doc):
        # We only output the first correct entity in a doc
...
@@ -359,8 +341,12 @@ class ReCoRD(Task):
        prediction = doc["entities"][max_idx]
        gold_label_set = doc["answers"]
        f1 = metric_max_over_ground_truths(
            squad_metrics.compute_f1, prediction, gold_label_set
        )
        em = metric_max_over_ground_truths(
            squad_metrics.compute_exact, prediction, gold_label_set
        )

        return {
            "f1": f1,
...
@@ -403,19 +389,21 @@ class WordsInContext(Task):
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        return (
            "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the"
            " two sentences above?\nAnswer:".format(
                doc["sentence1"],
                doc["sentence2"],
                doc["sentence1"][doc["start1"] : doc["end1"]],
            )
        )

    def doc_to_target(self, doc):
        return " {}".format({0: "no", 1: "yes"}[doc["label"]])

    def construct_requests(self, doc, ctx):
        ll_yes, _ = rf.loglikelihood(ctx, " yes")
        ll_no, _ = rf.loglikelihood(ctx, " no")

        return ll_yes, ll_no
...
@@ -423,21 +411,15 @@ class WordsInContext(Task):
        ll_yes, ll_no = results
        gold = doc["label"]

        acc = 1.0 if (ll_yes > ll_no) == gold else 0.0

        return {"acc": acc}

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}


class SGWinogradSchemaChallenge(Task):
...
@@ -461,9 +443,7 @@ class SGWinogradSchemaChallenge(Task):
        if self._training_docs is None:
            # GPT-3 Paper's format only uses positive examples for fewshot "training"
            self._training_docs = [
                doc for doc in self.dataset["train"] if doc["label"]
            ]
        return self._training_docs
...
@@ -473,25 +453,25 @@ class SGWinogradSchemaChallenge(Task):
    def doc_to_text(self, doc):
        raw_passage = doc["text"]
        # NOTE: HuggingFace span indices are word-based not character-based.
        pre = " ".join(raw_passage.split()[: doc["span2_index"]])
        post = raw_passage[len(pre) + len(doc["span2_text"]) + 1 :]
        passage = general_detokenize(pre + " *{}*".format(doc["span2_text"]) + post)
        noun = doc["span1_text"]
        pronoun = doc["span2_text"]
        text = (
            f"Passage: {passage}\n"
            + f'Question: In the passage above, does the pronoun "*{pronoun}*" refer to "*{noun}*"?\n'
            + "Answer:"
        )
        return text

    def doc_to_target(self, doc):
        return " " + yesno(doc["label"])

    def construct_requests(self, doc, ctx):
        ll_yes, _ = rf.loglikelihood(ctx, " yes")
        ll_no, _ = rf.loglikelihood(ctx, " no")

        return ll_yes, ll_no
...
@@ -499,18 +479,12 @@ class SGWinogradSchemaChallenge(Task):
        ll_yes, ll_no = results
        gold = doc["label"]

        acc = 1.0 if (ll_yes > ll_no) == gold else 0.0

        return {"acc": acc}

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}
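The binary tasks above (BoolQ, WiC, WSC) share one decision rule: compare the log-likelihoods of a " yes" and a " no" continuation. A worked sketch with invented numbers:

# Invented log-likelihoods for the " yes" / " no" continuations.
ll_yes, ll_no = -4.1, -5.7
gold = 1  # label 1 means the gold answer is yes
acc = 1.0 if (ll_yes > ll_no) == gold else 0.0
print(acc)  # 1.0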
lm_eval/tasks/translation.py (view file @ 1f8a8c1d)

...
@@ -41,44 +41,57 @@ def create_tasks_from_benchmarks(benchmark_dict):
    :return: {task_name: task}
        e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
    """

    def version_of(dataset, language_pair):
        if language_pair[-2:] in ["zh", "ja"]:
            return 1  # changed to use jieba/nagisa
        return 0

    return {
        f"{dataset}-{language_pair}": create_translation_task(
            dataset, language_pair, version_of(dataset, language_pair)
        )
        for dataset, language_pairs in benchmark_dict.items()
        for language_pair in language_pairs
    }


########################################
# Language Specifics
########################################


def zh_split(zh_text: List[str]) -> List[str]:
    """Chinese splitting"""
    import jieba

    return [" ".join(jieba.cut(txt.strip())) for txt in zh_text]


def ja_split(ja_text: List[str]) -> List[str]:
    """Japanese splitting"""
    import nagisa

    return [" ".join(nagisa.tagging(txt.strip()).words) for txt in ja_text]


NO_SPACE_LANG = {"zh": zh_split, "ja": ja_split}

########################################
# Tasks
########################################


def create_translation_task(dataset, language_pair, version=0):
    class TranslationTask(GeneralTranslationTask):
        VERSION = version

        def __init__(self):
            super().__init__(dataset, language_pair)

    return TranslationTask


class GeneralTranslationTask(Task):
    VERSION = 0
...
@@ -92,8 +105,9 @@ class GeneralTranslationTask(Task):
    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        # This caches in the users home dir automatically
        self.src_file, self.ref_file = sacrebleu.download_test_set(
            self.sacrebleu_dataset, self.sacrebleu_language_pair
        )
        self.src_data, self.ref_data = [
            [line.rstrip() for line in sacrebleu.smart_open(file)]
            for file in (self.src_file, self.ref_file)
...
@@ -117,10 +131,9 @@ class GeneralTranslationTask(Task):
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return [
            {"src": src, "ref": ref}
            for src, ref in zip(self.src_data, self.ref_data)
        ]

    def doc_to_text(self, doc):
        language_codes = self.sacrebleu_language_pair.split("-")
...
@@ -128,12 +141,18 @@ class GeneralTranslationTask(Task):
        tar_lang = code_to_language(language_codes[1])
        return f"{src_lang} phrase: " + doc["src"] + f"\n{tar_lang} phrase:"

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["src"]

    def doc_to_target(self, doc):
        # This shows a single target, though there may be multiple targets in a lang test
        return " " + doc["ref"] if isinstance(doc["ref"], str) else doc["ref"][0]

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
...
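A quick sketch of the zh_split tokenization that the version bump enables; it needs the jieba package, and the exact segmentation may vary with jieba's version and dictionary:

import jieba

zh_text = ["我爱自然语言处理"]
# Mirrors zh_split above: segment, then re-join with spaces for BLEU scoring.
print([" ".join(jieba.cut(txt.strip())) for txt in zh_text])
# e.g. ['我 爱 自然语言 处理']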
lm_eval/tasks/triviaqa.py (view file @ 1f8a8c1d)

...
@@ -43,10 +43,10 @@ class TriviaQA(Task):
        return False

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def test_docs(self):
        raise NotImplementedError()
...
@@ -54,8 +54,14 @@ class TriviaQA(Task):
    def doc_to_text(self, doc):
        return f"Question: {doc['question']}\nAnswer:"

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["question"]

    def doc_to_target(self, doc):
        return " " + doc["answer"]["value"]

    def _remove_prefixes(self, aliases):
        # Optimization: Remove any alias that has a strict prefix elsewhere in the list
...
@@ -69,15 +75,13 @@ class TriviaQA(Task):
    def construct_requests(self, doc, ctx):
        ret = []
        for alias in self._remove_prefixes(doc["answer"]["aliases"]):
            _, is_prediction = rf.loglikelihood(ctx, " " + alias)
            ret.append(is_prediction)
        return ret

    def process_results(self, doc, results):
        return {"acc": float(any(results))}

    def aggregation(self):
        return {
...
@@ -85,6 +89,4 @@ class TriviaQA(Task):
        }

    def higher_is_better(self):
        return {"acc": True}
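One way to implement the strict-prefix removal that the _remove_prefixes comment describes; this is a sketch, since the harness's own body is not shown in this hunk:

# After sorting, each prefix lands immediately before its extensions, so a
# single pass keeps only aliases that do not extend the last kept alias.
def remove_prefixes(aliases):
    aliases = sorted(aliases)
    kept = [aliases[0]]
    for alias in aliases[1:]:
        if not alias.startswith(kept[-1]):
            kept.append(alias)
    return kept

print(remove_prefixes(["Lisbon", "Lisbon, Portugal", "Lisboa"]))
# ['Lisboa', 'Lisbon']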
lm_eval/tasks/truthfulqa.py  View file @ 1f8a8c1d

```
...
@@ -80,22 +80,29 @@ class TruthfulQAMultipleChoice(Task):
        raise NotImplementedError()

    def doc_to_text(self, doc):
        return QA_PROMPT + "\n\nQ: " + doc["question"] + "\nA:"

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["question"]

    def doc_to_target(self, doc):
        return " "

    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        assert (
            num_fewshot == 0
        ), "TruthfulQA is intended only for the zero-shot setting."
        return super().fewshot_context(
            doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
        )

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
        ...
@@ -105,11 +112,15 @@ class TruthfulQAMultipleChoice(Task):
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """

        def get_lls(targets):
            return [rf.loglikelihood(ctx, " " + t)[0] for t in targets]

        # MC1 and MC2 targets are not always the same set of strings so we collect
        # likelihoods separately for simpler processing.
        return get_lls(doc["mc1_targets"]["choices"]) + get_lls(
            doc["mc2_targets"]["choices"]
        )

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        ...
@@ -121,37 +132,29 @@ class TruthfulQAMultipleChoice(Task):
        :param results:
            The results of the requests created in construct_requests.
        """

        def mc1(lls):
            # The gold answers in `mc1_targets` are always first (index = `0`).
            return np.argmax(lls) == 0

        def mc2(lls):
            # Split on the first `0` as everything before it is true (`1`).
            split_idx = list(doc["mc2_targets"]["labels"]).index(0)
            # Compute the normalized probability mass for the correct answer.
            ll_true, ll_false = lls[:split_idx], lls[split_idx:]
            p_true, p_false = np.exp(np.array(ll_true)), np.exp(np.array(ll_false))
            p_true = p_true / (sum(p_true) + sum(p_false))
            return sum(p_true)

        split_idx = len(doc["mc1_targets"]["choices"])
        mc1_lls, mc2_lls = results[:split_idx], results[split_idx:]
        return {"mc1": mc1(mc1_lls), "mc2": mc2(mc2_lls)}

    def aggregation(self):
        return {"mc1": mean, "mc2": mean}

    def higher_is_better(self):
        return {"mc1": True, "mc2": True}
```
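For readers skimming the diff, the `mc2` metric above is easy to illustrate with a self-contained sketch (the log-likelihood values are made up for illustration):

```
import numpy as np

# Hypothetical log-likelihoods for two true and two false reference
# answers, i.e. labels = [1, 1, 0, 0].
lls = [-1.0, -2.0, -1.5, -3.0]
split_idx = 2  # position of the first `0` label

ll_true, ll_false = lls[:split_idx], lls[split_idx:]
p_true, p_false = np.exp(np.array(ll_true)), np.exp(np.array(ll_false))
p_true = p_true / (sum(p_true) + sum(p_false))
print(sum(p_true))  # ~0.65: normalized probability mass on the true answers
```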
```
class TruthfulQAGeneration(Task):
    ...
@@ -181,44 +184,45 @@ class TruthfulQAGeneration(Task):
            answer = answer.strip()
            if len(answer):
                # Add a period after all answers.
                if answer[-1] != ".":
                    formatted_answers.append(answer + ".")
                else:
                    formatted_answers.append(answer)
        return formatted_answers

    def validation_docs(self):
        for doc in self.dataset["validation"]:
            incorrect_answers = self._format_answers(doc["incorrect_answers"])
            correct_answers = self._format_answers(doc["correct_answers"])
            if "I have no comment." not in correct_answers:
                correct_answers.append("I have no comment.")
            yield {
                "question": doc["question"].strip(),
                "correct_answers": correct_answers,
                "incorrect_answers": incorrect_answers,
            }

    def test_docs(self):
        raise NotImplementedError()

    def doc_to_text(self, doc):
        return QA_PROMPT + "\n\nQ: " + doc["question"]

    def doc_to_target(self, doc):
        return " "

    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        assert (
            num_fewshot == 0
        ), "TruthfulQA is intended only for the zero-shot setting."
        return super().fewshot_context(
            doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
        )

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
        ...
@@ -229,7 +233,7 @@ class TruthfulQAGeneration(Task):
            part of the document for `doc`.
        """
        # TODO: Find a way to cap the number of generated tokens to `50` as in the official implementation.
        completion = rf.greedy_until(ctx, ["."])
        return completion

    def process_results(self, doc, results):
        ...
@@ -243,18 +247,18 @@ class TruthfulQAGeneration(Task):
            The results of the requests created in construct_requests.
        """
        completion = results[0].strip()
        true_refs, false_refs = doc["correct_answers"], doc["incorrect_answers"]
        all_refs = true_refs + false_refs

        # Process the sentence-level BLEURT, BLEU, and ROUGE for similarity measures.

        # BLEURT
        bleurt_scores_true = self.bleurt.compute(
            predictions=[completion] * len(true_refs), references=true_refs
        )["scores"]
        bleurt_scores_false = self.bleurt.compute(
            predictions=[completion] * len(false_refs), references=false_refs
        )["scores"]
        bleurt_correct = max(bleurt_scores_true)
        bleurt_incorrect = max(bleurt_scores_false)
        bleurt_max = bleurt_correct
        ...
@@ -263,8 +267,8 @@ class TruthfulQAGeneration(Task):
        # BLEU
        bleu_scores = [self.bleu([[ref]], [completion]) for ref in all_refs]
        bleu_correct = np.nanmax(bleu_scores[: len(true_refs)])
        bleu_incorrect = np.nanmax(bleu_scores[len(true_refs) :])
        bleu_max = bleu_correct
        bleu_diff = bleu_correct - bleu_incorrect
        bleu_acc = int(bleu_correct > bleu_incorrect)
        ...
@@ -272,23 +276,23 @@ class TruthfulQAGeneration(Task):
        # ROUGE-N
        rouge_scores = [self.rouge([ref], [completion]) for ref in all_refs]
        # ROUGE-1
        rouge1_scores = [score["rouge1"] for score in rouge_scores]
        rouge1_correct = np.nanmax(rouge1_scores[: len(true_refs)])
        rouge1_incorrect = np.nanmax(rouge1_scores[len(true_refs) :])
        rouge1_max = rouge1_correct
        rouge1_diff = rouge1_correct - rouge1_incorrect
        rouge1_acc = int(rouge1_correct > rouge1_incorrect)
        # ROUGE-2
        rouge2_scores = [score["rouge2"] for score in rouge_scores]
        rouge2_correct = np.nanmax(rouge2_scores[: len(true_refs)])
        rouge2_incorrect = np.nanmax(rouge2_scores[len(true_refs) :])
        rouge2_max = rouge2_correct
        rouge2_diff = rouge2_correct - rouge2_incorrect
        rouge2_acc = int(rouge2_correct > rouge2_incorrect)
        # ROUGE-L
        rougeL_scores = [score["rougeLsum"] for score in rouge_scores]
        rougeL_correct = np.nanmax(rougeL_scores[: len(true_refs)])
        rougeL_incorrect = np.nanmax(rougeL_scores[len(true_refs) :])
        rougeL_max = rougeL_correct
        rougeL_diff = rougeL_correct - rougeL_incorrect
        rougeL_acc = int(rougeL_correct > rougeL_incorrect)
        ...
@@ -297,19 +301,15 @@ class TruthfulQAGeneration(Task):
            "bleurt_max": bleurt_max,
            "bleurt_acc": bleurt_acc,
            "bleurt_diff": bleurt_diff,
            "bleu_max": bleu_max,
            "bleu_acc": bleu_acc,
            "bleu_diff": bleu_diff,
            "rouge1_max": rouge1_max,
            "rouge1_acc": rouge1_acc,
            "rouge1_diff": rouge1_diff,
            "rouge2_max": rouge2_max,
            "rouge2_acc": rouge2_acc,
            "rouge2_diff": rouge2_diff,
            "rougeL_max": rougeL_max,
            "rougeL_acc": rougeL_acc,
            "rougeL_diff": rougeL_diff,
        ...
@@ -320,19 +320,15 @@ class TruthfulQAGeneration(Task):
            "bleurt_max": mean,
            "bleurt_acc": mean,
            "bleurt_diff": mean,
            "bleu_max": mean,
            "bleu_acc": mean,
            "bleu_diff": mean,
            "rouge1_max": mean,
            "rouge1_acc": mean,
            "rouge1_diff": mean,
            "rouge2_max": mean,
            "rouge2_acc": mean,
            "rouge2_diff": mean,
            "rougeL_max": mean,
            "rougeL_acc": mean,
            "rougeL_diff": mean,
        ...
@@ -343,19 +339,15 @@ class TruthfulQAGeneration(Task):
            "bleurt_max": True,
            "bleurt_acc": True,
            "bleurt_diff": True,
            "bleu_max": True,
            "bleu_acc": True,
            "bleu_diff": True,
            "rouge1_max": True,
            "rouge1_acc": True,
            "rouge1_diff": True,
            "rouge2_max": True,
            "rouge2_acc": True,
            "rouge2_diff": True,
            "rougeL_max": True,
            "rougeL_acc": True,
            "rougeL_diff": True,
        ...
@@ -379,7 +371,7 @@ class TruthfulQAGeneration(Task):
            force=False,
            lowercase=False,
            tokenize="intl",
            use_effective_order=False,
        ).score
        return score
        ...
@@ -396,9 +388,11 @@ class TruthfulQAGeneration(Task):
        rouge_types = ["rouge1", "rouge2", "rougeLsum"]
        scorer = rouge_scorer.RougeScorer(rouge_types)

        # Add newlines between sentences to correctly compute `rougeLsum`.
        def _prepare_summary(summary):
            summary = summary.replace(" . ", ".\n")
            return summary

        # Accumulate confidence intervals.
        aggregator = scoring.BootstrapAggregator()
        for ref, pred in zip(refs, preds):
        ...
@@ -406,4 +400,4 @@ class TruthfulQAGeneration(Task):
            pred = _prepare_summary(pred)
            aggregator.add_scores(scorer.score(ref, pred))
        result = aggregator.aggregate()
        return {type: result[type].mid.fmeasure * 100 for type in rouge_types}
```
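The BLEURT, BLEU, and ROUGE metrics above all follow the same max/diff/acc pattern; a minimal sketch with made-up similarity scores:

```
import numpy as np

# Hypothetical similarity scores of one completion against two true and
# two false reference answers (values are illustrative only).
scores = [0.40, 0.55, 0.30, 0.20]
n_true = 2

correct = np.nanmax(scores[:n_true])    # best match among true answers: 0.55
incorrect = np.nanmax(scores[n_true:])  # best match among false answers: 0.30
print(correct, correct - incorrect, int(correct > incorrect))  # max, diff, acc
```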
lm_eval/tasks/unscramble.py  View file @ 1f8a8c1d

```
...
@@ -49,6 +49,12 @@ class WordUnscrambleTask(Task):
    def doc_to_text(self, doc):
        return doc["context"]

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["context"]

    def doc_to_target(self, doc):
        return doc["completion"]
    ...
@@ -59,19 +65,13 @@ class WordUnscrambleTask(Task):
    def process_results(self, doc, results):
        pred = results[0]
        gold = doc["completion"]
        return {"acc": int(pred == gold)}

    def aggregation(self):
        return {"acc": mean}

    def higher_is_better(self):
        return {"acc": True}


class Anagrams1(WordUnscrambleTask):
    ...
```
lm_eval/tasks/webqs.py  View file @ 1f8a8c1d

```
...
@@ -54,14 +54,20 @@ class WebQs(Task):
        return self.dataset["test"]

    def doc_to_text(self, doc):
        return "Question: " + doc["question"] + "\nAnswer:"

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["question"]

    def doc_to_target(self, doc):
        # this picks one answer to be the "correct" one, despite sometimes
        # multiple correct answers being possible.
        # TODO: make sure we're actually handling multi-answer correctly
        return " " + doc["answers"][0]

    def _remove_prefixes(self, aliases):
        # Optimization: Remove any alias that has a strict prefix elsewhere in the list
        # we can do this because if the prefix is acceptable by isgreedy, we can stop looking
        ...
@@ -75,15 +81,13 @@ class WebQs(Task):
    def construct_requests(self, doc, ctx):
        ret = []
        for alias in self._remove_prefixes(doc["answers"]):
            _, is_prediction = rf.loglikelihood(ctx, " " + alias)
            ret.append(is_prediction)
        return ret

    def process_results(self, doc, results):
        return {"acc": float(any(results))}

    def aggregation(self):
        return {
            ...
@@ -91,6 +95,4 @@ class WebQs(Task):
        }

    def higher_is_better(self):
        return {"acc": True}
```
lm_eval/tasks/wikitext.py  View file @ 1f8a8c1d

```
...
@@ -2,7 +2,7 @@
Pointer Sentinel Mixture Models
https://arxiv.org/pdf/1609.07843.pdf

The WikiText language modeling dataset is a collection of over 100 million tokens
extracted from the set of verified Good and Featured articles on Wikipedia.

NOTE: This `Task` is based on WikiText-2.
...
@@ -17,7 +17,7 @@ from lm_eval.base import PerplexityTask
_CITATION = """
@misc{merity2016pointer,
    title={Pointer Sentinel Mixture Models},
    author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},
    year={2016},
    eprint={1609.07843},
...
@@ -90,6 +90,9 @@ class WikiText(PerplexityTask):
    def doc_to_target(self, doc):
        return wikitext_detokenizer(doc)

    def should_decontaminate(self):
        return True

    def count_words(self, doc):
        # count number of words in *original doc before detokenization*
        return len(re.split(r"\s+", doc))
```
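The `count_words` hook exists because perplexity tasks can normalize by a unit other than model tokens, making scores comparable across tokenizers. A rough sketch of the idea (not the harness's exact code; numbers are hypothetical):

```
import math

# Word-level perplexity: normalize the summed log-likelihood by the
# whitespace-word count rather than by the number of model tokens.
total_loglikelihood = -2471.3   # hypothetical sum over one document
num_words = 1200                # as returned by count_words() above
word_perplexity = math.exp(-total_loglikelihood / num_words)
print(word_perplexity)  # ~7.8
```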
lm_eval/tasks/winogrande.py  View file @ 1f8a8c1d

```
"""
WinoGrande: An Adversarial Winograd Schema Challenge at Scale
https://arxiv.org/pdf/1907.10641.pdf

WinoGrande is a collection of 44k problems, inspired by Winograd Schema Challenge
(Levesque, Davis, and Morgenstern 2011), but adjusted to improve the scale and
robustness against the dataset-specific bias. Formulated as a fill-in-a-blank
task with binary options, the goal is to choose the right option for a given
sentence which requires commonsense reasoning.

NOTE: This evaluation of Winogrande uses partial evaluation as described by
Trinh & Le in Simple Method for Commonsense Reasoning (2018).
See: https://arxiv.org/abs/1806.02847

Homepage: https://leaderboard.allenai.org/winogrande/submissions/public
"""
import numpy as np

from lm_eval.base import rf, Task
from lm_eval.metrics import mean


_CITATION = """
@article{sakaguchi2019winogrande,
    title={WinoGrande: An Adversarial Winograd Schema Challenge at Scale},
    author={Sakaguchi, Keisuke and Bras, Ronan Le and Bhagavatula, Chandra and Choi, Yejin},
    journal={arXiv preprint arXiv:1907.10641},
    year={2019}
}
"""


class Winogrande(Task):
    VERSION = 0
    DATASET_PATH = "winogrande"
    DATASET_NAME = "winogrande_xl"

    answer_to_num = {"1": 0, "2": 1}

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        return self.partial_context(doc, doc["option" + doc["answer"]])

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["sentence"]

    @classmethod
    def partial_context(cls, doc, option):
        # Substitute the pronoun in the sentence with the specified option
        # and ignore everything after.
        pronoun_loc = doc["sentence"].index("_")
        return doc["sentence"][:pronoun_loc] + option

    def doc_to_target(self, doc):
        return self.partial_target(doc)

    @classmethod
    def partial_target(cls, doc):
        # The target is everything after the document specified pronoun.
        pronoun_loc = doc["sentence"].index("_") + 1
        return " " + doc["sentence"][pronoun_loc:].strip()

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        target = self.partial_target(doc)
        lls = []
        for option in [doc["option1"], doc["option2"]]:
            partial_ctx = self.partial_context(doc, option)
            full_ctx = self.append_context(ctx, partial_ctx)
            lls.append(rf.loglikelihood(full_ctx, target)[0])
        return lls

    @classmethod
    def append_context(cls, ctx, partial_ctx):
        ctx = ctx.split("\n\n")  # Each fewshot context is on its own new line.
        ctx.pop()  # Remove the correct context put in by `doc_to_text`.
        return "\n\n".join([*ctx, partial_ctx]) if ctx else partial_ctx

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        return {"acc": np.argmax(results) == self.answer_to_num[doc["answer"]]}

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {"acc": mean}

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {"acc": True}
```
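To make the partial-evaluation trick concrete, here is a small illustrative sketch (the sentence is a typical WinoGrande-style item, not taken from the dataset):

```
doc = {
    "sentence": "The trophy doesn't fit in the suitcase because _ is too small.",
    "option1": "the trophy",
    "option2": "the suitcase",
    "answer": "2",
}

pronoun_loc = doc["sentence"].index("_")
# One scored context per option; the continuation after the blank is the
# shared target whose log-likelihood decides the answer.
for option in [doc["option1"], doc["option2"]]:
    context = doc["sentence"][:pronoun_loc] + option
    target = " " + doc["sentence"][pronoun_loc + 1 :].strip()
    print(repr(context), "->", repr(target))
# The option whose context assigns " is too small." the higher
# log-likelihood wins; a good model should pick "the suitcase".
```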
lm_eval/tasks/wsc273.py  View file @ 1f8a8c1d

```
...
@@ -40,8 +40,19 @@ class WinogradSchemaChallenge273(Task):
    DATASET_PATH = "winograd_wsc"
    DATASET_NAME = "wsc273"

    upper_pronouns = [
        "A",
        "An",
        "The",
        "She",
        "He",
        "It",
        "They",
        "My",
        "His",
        "Her",
        "Their",
    ]

    def has_training_docs(self):
        return False
    ...
@@ -68,7 +79,7 @@ class WinogradSchemaChallenge273(Task):
            option += "'s"
        # Appropriately lowercase the pronoun in the option.
        pronoun = option.split()[0]
        start_of_sentence = doc["text"][doc["pronoun_loc"] - 2] == "."
        if not start_of_sentence and pronoun in self.upper_pronouns:
            return option.replace(pronoun, pronoun.lower())
        return option
    ...
@@ -85,11 +96,17 @@ class WinogradSchemaChallenge273(Task):
    def doc_to_text(self, doc):
        return self.partial_context(doc, doc["options"][doc["label"]])

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["text"]

    @classmethod
    def partial_context(cls, doc, option):
        # Substitute the pronoun in the original text with the specified
        # option and ignore everything after.
        return doc["text"][: doc["pronoun_loc"]] + option

    def doc_to_target(self, doc):
        return self.partial_target(doc)
    ...
@@ -135,9 +152,7 @@ class WinogradSchemaChallenge273(Task):
        :param results:
            The results of the requests created in construct_requests.
        """
        return {"acc": np.argmax(results) == doc["label"]}

    def aggregation(self):
        """
        ...
@@ -145,9 +160,7 @@ class WinogradSchemaChallenge273(Task):
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {"acc": mean}

    def higher_is_better(self):
        """
        ...
@@ -155,6 +168,4 @@ class WinogradSchemaChallenge273(Task):
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {"acc": True}
```
lm_eval/utils.py  View file @ 1f8a8c1d

```
...
@@ -34,6 +34,7 @@ def simple_parse_args_string(args_string):
        args_dict[k] = v
    return args_dict


def join_iters(iters):
    for iter in iters:
        yield from iter
    ...
@@ -46,23 +47,26 @@ def chunks(iter, n):
        if len(arr) == n:
            yield arr
            arr = []

    if arr:
        yield arr


def group(arr, fn):
    res = collections.defaultdict(list)

    for ob in arr:
        res[fn(ob)].append(ob)

    return list(res.values())


def general_detokenize(string):
    string = string.replace(" n't", "n't")
    string = string.replace(" )", ")")
    string = string.replace("( ", "(")
    string = string.replace('" ', '"')
    string = string.replace(' "', '"')
    string = re.sub(r" (['.,])", r"\1", string)
    return string
    ...
@@ -94,10 +98,7 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len
    # Special handling for first window: predict all tokens
    first_seq_len = min(max_seq_len, len(token_list))
    yield (
        [prefix_token] + token_list[: first_seq_len - 1],
        token_list[:first_seq_len],
    )
    predicted += first_seq_len

    while predicted < len(token_list):
    ...
@@ -105,61 +106,66 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len
        window_end = predicted + window_pred_len

        yield (
            token_list[window_end - max_seq_len - 1 : window_end - 1],
            token_list[window_end - window_pred_len : window_end],
        )
        predicted += window_pred_len


def make_disjoint_window(pair):
    """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
    a, b = pair
    return a[: -(len(b) - 1)], b
```
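The two helpers above are easiest to understand together; a small illustrative driver (token ids are arbitrary integers):

```
from lm_eval.utils import get_rolling_token_windows, make_disjoint_window

# 13 fake token ids, a prefix token of -1 (e.g. EOT), max sequence
# length 5, and 2 tokens of retained context between windows.
windows = get_rolling_token_windows(
    token_list=list(range(13)), prefix_token=-1, max_seq_len=5, context_len=2
)
for context, continuation in map(make_disjoint_window, windows):
    print(context, "->", continuation)
# Each token id lands in exactly one continuation, so summing continuation
# log-likelihoods over the windows scores the whole sequence exactly once.
```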
```
class Reorderer:
    def __init__(self, arr, fn):
        self.size = len(arr)
        arr = list(enumerate(arr))
        arr = group(arr, lambda x: fn(x[1]))
        arr = [([y[0] for y in x], x[0][1]) for x in arr]
        arr.sort(key=lambda x: fn(x[1]))

        self.arr = arr

    def get_reordered(self):
        return [x[1] for x in self.arr]

    def get_original(self, newarr):
        res = [None] * self.size
        cov = [False] * self.size

        for (inds, _), v in zip(self.arr, newarr):
            for ind in inds:
                res[ind] = v
                cov[ind] = True

        assert all(cov)

        return res
```
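`Reorderer` lets the evaluator sort requests for efficient batching and then restore the caller's order; a toy run (keying on string length, which happens to identify each distinct string in this input):

```
reorderer = Reorderer(["bb", "a", "ccc", "a"], fn=len)
batch = reorderer.get_reordered()        # ["a", "bb", "ccc"], sorted by key
results = [s.upper() for s in batch]     # stand-in for an expensive LM call
print(reorderer.get_original(results))   # ["BB", "A", "CCC", "A"]
# Items with equal keys share one computed result, so the key must identify
# the work item uniquely; the duplicate "a" is deduplicated here.
```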
```
def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
    wrapped function, `fn`.
    """

    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        if len(args) != 1 if inspect.ismethod(fn) else 0:
            print(
                f"WARNING: using {fn.__name__} with positional arguments is "
                "deprecated and will be disallowed in a future version of "
                "lm-evaluation-harness!"
            )
        return fn(*args, **kwargs)

    return _wrapper


@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
    """
    ...
@@ -169,12 +175,14 @@ def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
    cur_path = start_path.resolve()
    max_layers = 3
    for _ in range(max_layers):
        if (cur_path / "tests" / "test_version_stable.py").exists():
            return cur_path
        else:
            cur_path = cur_path.parent.resolve()
    raise FileNotFoundError(
        f"Unable to find package root within {max_layers} upwards " + f"of {start_path}"
    )


@positional_deprecated
def run_task_tests(task_list: List[str]):
    ...
@@ -182,9 +190,16 @@ def run_task_tests(task_list: List[str]):
    Find the package root and run the tests for the given tasks
    """
    package_root = find_test_root(start_path=pathlib.Path(__file__))
    task_string = " or ".join(task_list)
    args = [
        f"{package_root}/tests/test_version_stable.py",
        f"--rootdir={package_root}",
        "-k",
        f"{task_string}",
    ]
    sys.path.append(str(package_root))
    pytest_return_val = pytest.main(args)
    if pytest_return_val:
        raise ValueError(
            f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
        )
```
main.py  View file @ 1f8a8c1d

```
import argparse
import json
import logging
import fnmatch

from lm_eval import tasks, evaluator

logging.getLogger("openai").setLevel(logging.WARNING)


class MultiChoice:
    def __init__(self, choices):
        self.choices = choices

    # Simple wildcard support (linux filename patterns)
    def __contains__(self, values):
        for value in values.split(","):
            if len(fnmatch.filter(self.choices, value)) == 0:
                return False

        return True

    def __iter__(self):
        for choice in self.choices:
            yield choice


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--model_args", default="")
    parser.add_argument("--tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS))
    parser.add_argument("--provide_description", action="store_true")
    parser.add_argument("--num_fewshot", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--output_path", default=None)
    parser.add_argument("--limit", type=int, default=None)
    parser.add_argument("--no_cache", action="store_true")
    parser.add_argument("--decontamination_ngrams_path", default=None)
    parser.add_argument("--description_dict_path", default=None)
    parser.add_argument("--check_integrity", action="store_true")

    return parser.parse_args()


# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
    task_names = set()
    for pattern in patterns:
        for matching in fnmatch.filter(source_list, pattern):
            task_names.add(matching)
    return list(task_names)


def main():
    args = parse_args()

    assert not args.provide_description  # not implemented

    if args.limit:
        print(
            "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )

    if args.tasks is None:
        task_names = tasks.ALL_TASKS
    else:
        task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)

    print(f"Selected Tasks: {task_names}")

    description_dict = {}
    if args.description_dict_path:
        with open(args.description_dict_path, "r") as f:
            description_dict = json.load(f)

    results = evaluator.simple_evaluate(
    ...
@@ -51,11 +86,11 @@ def main():
        no_cache=args.no_cache,
        limit=args.limit,
        description_dict=description_dict,
        decontamination_ngrams_path=args.decontamination_ngrams_path,
        check_integrity=args.check_integrity,
    )

    dumped = json.dumps(results, indent=2)
    print(dumped)

    if args.output_path:
    ...
```
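With the `MultiChoice`/`pattern_match` additions, `--tasks` accepts comma-separated shell-style globs instead of only exact task names. For example (model and task names illustrative):

```
python main.py --model gpt2 --tasks "lambada,hendrycksTest-*" --num_fewshot 0
```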
pile_statistics.json  0 → 100644  View file @ 1f8a8c1d

```
{
    "Data": "Pile statistics",
    "Document Count": 210607728,
    "Total Pile Characters": 421215456,
    "File Start Offsets": [
        0, 7021438, 14042822, 21066113, 28086515, 35106072,
        42123306, 49145091, 56165817, 63185587, 70211208, 77234322,
        84249267, 91267634, 98285983, 105305110, 112322489, 119342491,
        126367373, 133389153, 140412039, 147432373, 154452516, 161470190,
        168492733, 175512521, 182526939, 189547478, 196565318, 203583306
    ]
}
```
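A plausible consumer of this table, assuming each offset is the index of the first document stored in the corresponding Pile shard (a sketch, not the actual pipeline code):

```
import bisect

file_start_offsets = [0, 7021438, 14042822]  # first three entries shown

def shard_for_document(doc_index):
    # The shard is the last file whose start offset is <= doc_index.
    return bisect.bisect_right(file_start_offsets, doc_index) - 1

print(shard_for_document(7021437))  # 0: last document of the first shard
print(shard_for_document(7021438))  # 1: first document of the second shard
```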
scripts/clean_training_data/README.md  View file @ 1f8a8c1d

janitor.py contains a script to remove benchmark data contamination from training data sets.
It uses the approach described in the [GPT-3 paper](https://arxiv.org/abs/2005.14165).

## Algorithm
1) Collects all contamination text files that are to be removed from training data
2) Filters training data by finding `N`gram matches between the training data
   and any contamination (see the sketch after this file's diff)
   1) `N`grams ignore case and punctuation and are split on whitespace.
   2) Matching `N`gram substrings are removed, as is a `window_to_remove` character window around
      the match, splitting the training data into chunks
   3) Any chunks less than `minimum_slice_length` are removed
   4) Training data sets split into more than `too_dirty_cutoff` chunks are considered
      completely contaminated and removed

OpenAI used:
```
ngram_n = 13
...
@@ -20,7 +20,7 @@ minimum_slice_length = 200
too_dirty_cutoff = 10
```

## Compiling

Janitor can be used as a pure python program, but it is much faster if the ngram
code is run in C++. To compile the C++ code, run
...
@@ -31,4 +31,3 @@ c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor
```
If your compiler isn't linked to python, you may need to add to the above `-undefined dynamic_lookup`
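A minimal sketch of the matching step, assuming only the simple normalization described above (lowercase, strip punctuation, split on whitespace); the real janitor.py is considerably more involved:

```
import string

def ngrams(text, n=13):
    # Normalize as described above: ignore case and punctuation.
    words = (
        text.lower().translate(str.maketrans("", "", string.punctuation)).split()
    )
    return {" ".join(words[i : i + n]) for i in range(len(words) - n + 1)}

benchmark = "The quick brown fox jumps over the lazy dog near the river bank today."
training_doc = (
    "...The quick brown fox jumps over the lazy dog near the river bank today, said..."
)

# Any shared 13-gram flags the training document as contaminated.
print(bool(ngrams(benchmark) & ngrams(training_doc)))  # True
```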
scripts/clean_training_data/compress_and_package.py  0 → 100644  View file @ 1f8a8c1d

```
import glob
import argparse
import os
import subprocess
import shutil

from tqdm import tqdm
from tqdm_multiprocess import TqdmMultiProcessPool

import logging
from tqdm_multiprocess.logger import setup_logger_tqdm

logger = logging.getLogger(__name__)


def process_task(
    working_directory, output_directory, bucket_file_path, tqdm_func, global_tqdm
):
    command = f"zstd {bucket_file_path}"
    logger.info(command)
    subprocess.call(command, shell=True)

    compressed_file = bucket_file_path + ".zst"
    if output_directory:
        shutil.move(compressed_file, output_directory)

    os.remove(bucket_file_path)
    global_tqdm.update()


def compress_and_move(working_directory, output_directory, process_count):
    os.makedirs(output_directory, exist_ok=True)
    original_info_file_path = os.path.join(working_directory, "info.json")
    assert os.path.exists(original_info_file_path)

    tasks = []
    bucket_file_paths = glob.glob(
        os.path.join(working_directory, "output", f"*.bkt.txt.sorted")
    )
    for bucket_file_path in bucket_file_paths:
        task = (process_task, (working_directory, output_directory, bucket_file_path))
        tasks.append(task)

    pool = TqdmMultiProcessPool(process_count)

    def on_done(_):
        return None

    def on_error(_):
        return None

    global_progress = tqdm(
        total=len(bucket_file_paths), dynamic_ncols=True, unit="file"
    )
    _ = pool.map(global_progress, tasks, on_error, on_done)

    shutil.copy(original_info_file_path, os.path.join(output_directory, "info.json"))


parser = argparse.ArgumentParser(description="sort 13gram buckets")
parser.add_argument("-dir", "--working_directory", required=True)
parser.add_argument("-output", "--output_directory", required=True)
parser.add_argument("-procs", "--process_count", type=int, default=8)

if __name__ == "__main__":
    version = 1.00
    print(f"Running version {version}")

    logfile_path = "compress_and_package.log"
    setup_logger_tqdm(logfile_path)

    args = parser.parse_args()
    compress_and_move(
        args.working_directory, args.output_directory, args.process_count
    )
```
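A typical invocation, assuming the sorted `*.bkt.txt.sorted` bucket files produced by the earlier pipeline stage live under `working/output/` and the `zstd` binary is on PATH:

```
python scripts/clean_training_data/compress_and_package.py \
    -dir working -output packaged -procs 8
```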