gaoqiong / lm-evaluation-harness · Commits

Commit 6caa0afd (Unverified)
Authored Apr 01, 2022 by Leo Gao; committed by GitHub on Apr 01, 2022.

    Merge pull request #300 from jon-tow/hf-dataset-refactor

    Refactor `Task` downloading to use `HuggingFace.datasets`

Parents: 7064d6b9, 9434722c
Changes: 87 · Showing 20 changed files with 405 additions and 602 deletions (+405 −602)
Changed files on this page of the diff:

  lm_eval/tasks/anli.py                   +9    −7
  lm_eval/tasks/arc.py                    +13   −3
  lm_eval/tasks/arithmetic.py             +26   −60
  lm_eval/tasks/asdiv.py                  +9    −43
  lm_eval/tasks/blimp.py                  +12   −8
  lm_eval/tasks/cbt.py                    +21   −3
  lm_eval/tasks/common.py                 +0    −52
  lm_eval/tasks/coqa.py                   +14   −23
  lm_eval/tasks/drop.py                   +33   −36
  lm_eval/tasks/glue.py                   +88   −17
  lm_eval/tasks/gsm8k.py                  +6    −23
  lm_eval/tasks/headqa.py                 +19   −4
  lm_eval/tasks/hellaswag.py              +18   −11
  lm_eval/tasks/hendrycks_ethics.py       +59   −93
  lm_eval/tasks/hendrycks_math.py         +19   −45
  lm_eval/tasks/hendrycks_test.py         +19   −49
  lm_eval/tasks/lambada.py                +4    −20
  lm_eval/tasks/lambada_cloze.py          +1    −5
  lm_eval/tasks/lambada_multilingual.py   +17   −55
  lm_eval/tasks/logiqa.py                 +18   −45
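The diffs below all follow one pattern: hand-rolled `download()` methods built on `best_download`, `sh`, and local `data/` folders are dropped, and each task instead declares a `DATASET_PATH` (a Hugging Face Hub name, or the path of a dataset script bundled with the harness via `inspect.getfile`) plus an optional `DATASET_NAME`, then reads its splits from `self.dataset`. A minimal sketch of that loading step, written here for illustration and assuming the base `Task` wires `datasets.load_dataset` into `self.dataset` (much as the deleted `HFTask.download` below did with `self.data`):

import datasets

# What a task now declares (ANLI shown; a GLUE task would use "glue" plus a config such as "cola")
DATASET_PATH = "anli"
DATASET_NAME = None

# What the harness is expected to run on the task's behalf
dataset = datasets.load_dataset(path=DATASET_PATH, name=DATASET_NAME)
print(list(dataset.keys()))    # ANLI exposes per-round splits: train_r1, dev_r1, test_r1, ..., test_r3
print(dataset["train_r1"][0])  # each split yields plain dicts, hence the doc["field"] accessors in the new code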
lm_eval/tasks/anli.py  (view file @ 6caa0afd)

@@ -10,9 +10,8 @@ provided explanations.
 Homepage: "https://github.com/facebookresearch/anli"
 """
 import numpy as np
-from lm_eval.base import rf
-from ..metrics import mean
-from .common import HFTask
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean

@@ -31,7 +30,7 @@ _CITATION = """
-class ANLIBase(HFTask):
+class ANLIBase(Task):
     VERSION = 0
     DATASET_PATH = "anli"
     DATASET_NAME = None

@@ -49,16 +48,16 @@ class ANLIBase(HFTask):
     def training_docs(self):
         if self.has_training_docs():
             if self._training_docs is None:
-                self._training_docs = list(self.data["train_r" + str(self.SPLIT)])
+                self._training_docs = list(self.dataset["train_r" + str(self.SPLIT)])
             return self._training_docs

     def validation_docs(self):
         if self.has_validation_docs():
-            return self.data["dev_r" + str(self.SPLIT)]
+            return self.dataset["dev_r" + str(self.SPLIT)]

     def test_docs(self):
         if self.has_test_docs():
-            return self.data["test_r" + str(self.SPLIT)]
+            return self.dataset["test_r" + str(self.SPLIT)]

     def doc_to_text(self, doc):
         # OA does this a bit weirdly: they prepend "anli 1: anli 1: " to the beginning

@@ -125,11 +124,14 @@ class ANLIBase(HFTask):
             "acc": True
         }

 class ANLIRound1(ANLIBase):
     SPLIT = 1

 class ANLIRound2(ANLIBase):
     SPLIT = 2

 class ANLIRound3(ANLIBase):
     SPLIT = 3
lm_eval/tasks/arc.py  (view file @ 6caa0afd)

@@ -13,7 +13,6 @@ a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questions...
 Homepage: https://allenai.org/data/arc
 """
 from lm_eval.base import MultipleChoiceTask
-from .common import HFTask

@@ -27,7 +26,7 @@ _CITATION = """
-class ARCEasy(HFTask, MultipleChoiceTask):
+class ARCEasy(MultipleChoiceTask):
     VERSION = 0
     DATASET_PATH = "ai2_arc"
     DATASET_NAME = "ARC-Easy"

@@ -41,7 +40,18 @@ class ARCEasy(HFTask, MultipleChoiceTask):
     def has_test_docs(self):
         return True

-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
+
+    def _process_doc(self, doc):
         # NOTE: Some `doc["answerKey"]`s are in numeric string format being one
         # of {'1', '2', '3', '4', '5'}. We map them back to letters.
         num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}
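The tail of `_process_doc` is truncated above. A hedged sketch of how such an answer-key normalization typically feeds the multiple-choice fields (the `question`, `choices.text`/`choices.label`, and `answerKey` names come from the `ai2_arc` schema; the helper name and prompt format are illustrative, not taken from this commit):

def process_arc_doc(doc):
    # Map numeric answer keys ('1'..'5') onto letters so they can be looked up
    # among the letter labels that accompany each choice.
    num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}
    answer_key = num_to_letter.get(doc["answerKey"], doc["answerKey"])
    return {
        "query": "Question: " + doc["question"] + "\nAnswer:",   # illustrative prompt
        "choices": doc["choices"]["text"],
        "gold": doc["choices"]["label"].index(answer_key),
    }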
lm_eval/tasks/arithmetic.py  (view file @ 6caa0afd)

@@ -7,13 +7,10 @@ problem in natural language.
 Homepage: https://github.com/openai/gpt-3/tree/master/data
 """
-import abc
-import json
-import os
-from collections import namedtuple
+import inspect
+import lm_eval.datasets.arithmetic.arithmetic
 from lm_eval.base import Task, rf
 from lm_eval.metrics import mean
-from best_download import download_file

@@ -31,33 +28,9 @@ _CITATION = """
-ArithmeticDoc = namedtuple('ArithmeticDoc', ['context', 'completion'])

 class Arithmetic(Task):
     VERSION = 0
-    directory = 'data/arithmetic/'
-
-    def __init__(self):
-        super().__init__()
-
-    def download(self):
-        file_name, checksum = self.get_file_download_info()
-        url = 'https://raw.githubusercontent.com/openai/gpt-3/master/data/' + file_name
-        if not os.path.exists(self.directory):
-            os.makedirs(self.directory)
-        download_file(url, local_file=self.directory + file_name, expected_checksum=checksum)
-        self.set_docs()
-
-    @abc.abstractmethod
-    def get_file_download_info(self):
-        """returns a tuple of (file_name, checksum)"""
-        pass
-
-    def set_docs(self):
-        file_name, _ = self.get_file_download_info()
-        jsons = open(self.directory + file_name, 'r')
-        self._docs = [self.load_doc(json.loads(line)) for line in jsons]
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.arithmetic.arithmetic)

     def has_training_docs(self):
         return False

@@ -72,25 +45,19 @@ class Arithmetic(Task):
         return NotImplemented

     def validation_docs(self):
-        return self._docs
+        return self.dataset["validation"]

     def test_docs(self):
         return NotImplemented

     def doc_to_text(self, doc):
-        return doc.context
+        return doc["context"]

     def doc_to_target(self, doc):
-        return doc.completion
-
-    def load_doc(self, doc_json):
-        return ArithmeticDoc(
-            context=doc_json['context'].strip()
-                .replace('\n\n', '\n')
-                .replace('Q:', 'Question:')
-                .replace('A:', 'Answer:'),
-            completion=doc_json['completion'])
+        return doc["completion"]

     def construct_requests(self, doc, ctx):
-        ll, is_prediction = rf.loglikelihood(ctx, doc.completion)
+        ll, is_prediction = rf.loglikelihood(ctx, doc["completion"])
         return is_prediction

     def process_results(self, doc, results):

@@ -111,41 +78,40 @@ class Arithmetic(Task):
Each Arithmetic* subclass drops its get_file_download_info() override (which returned a GPT-3 JSONL
file name plus a SHA-256 checksum) in favour of a DATASET_NAME config selector:
  Arithmetic2DPlus → "arithmetic_2da", Arithmetic2DMinus → "arithmetic_2ds",
  Arithmetic3DPlus → "arithmetic_3da", Arithmetic3DMinus → "arithmetic_3ds",
  Arithmetic4DPlus → "arithmetic_4da", Arithmetic4DMinus → "arithmetic_4ds",
  Arithmetic5DPlus → "arithmetic_5da", Arithmetic5DMinus → "arithmetic_5ds",
  Arithmetic2DMultiplication → "arithmetic_2dm", Arithmetic1DComposite → "arithmetic_1dc"
lm_eval/tasks/asdiv.py  (view file @ 6caa0afd)

@@ -14,15 +14,10 @@ NOTE: We currently ignore formulas for answer generation.
 Homepage: https://github.com/chaochun/nlu-asdiv-dataset
 """
-from lm_eval.base import Task
-from pathlib import Path
-from best_download import download_file
-import xml.etree.ElementTree as ET
-from lm_eval.base import rf
-from lm_eval.metrics import mean, perplexity
-import numpy as np
-from zipfile import ZipFile
-import os
+import inspect
+import lm_eval.datasets.asdiv.asdiv
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean

@@ -39,39 +34,11 @@ _CITATION = """
 class Asdiv(Task):
     VERSION = 0
-    DATASET_PATH = Path("data/asdiv")
-
-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        url = "https://github.com/chaochun/nlu-asdiv-dataset/archive/55790e5270bb91ccfa5053194b25732534696b50.zip"
-        checksum = "8f1fe4f6d5f170ec1e24ab78c244153c14c568b1bb2b1dad0324e71f37939a2d"
-        zip_path = self.DATASET_PATH / "55790e5270bb91ccfa5053194b25732534696b50.zip"
-        download_file(url, local_file=str(zip_path), expected_checksum=checksum)
-        with ZipFile(zip_path, "r") as zip:
-            zip.extractall(self.DATASET_PATH)
-        os.remove(zip_path)
-
-    def _convert_standard(self, problem):
-        # TODO: include solution-type and formula
-        out_doc = {
-            "question": problem.find('Question').text,
-            "body": problem.find('Body').text,
-            "answer": problem.find('Answer').text
-        }
-        return out_doc
-
-    def load_docs(self, textfilename, tfds=False):
-        tree = ET.parse(textfilename)
-        root = tree.getroot()
-        for pid, problem in enumerate(root.iter('Problem')):
-            out_doc = self._convert_standard(problem)
-            yield out_doc
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.asdiv.asdiv)

     def has_training_docs(self):
         return False

     def has_validation_docs(self):
         return True

@@ -81,13 +48,12 @@ class Asdiv(Task):
     def training_docs(self):
         raise NotImplementedError("This dataset has no training docs")

+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def test_docs(self):
         raise NotImplementedError("This dataset has no test docs")

-    def validation_docs(self):
-        data_xml_path = self.DATASET_PATH / "nlu-asdiv-dataset-55790e5270bb91ccfa5053194b25732534696b50/dataset/ASDiv.xml"
-        return self.load_docs(data_xml_path)
-
     def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
         assert num_fewshot == 0, "ASDiv is intended only for the zero-shot setting."
         return super().fewshot_context(
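Both `arithmetic.py` and `asdiv.py` point `DATASET_PATH` at a dataset script that ships inside the harness rather than at a Hub dataset. A sketch of how that resolves (module and config names as in the diffs above; the `load_dataset` call itself is ordinary `datasets` usage, shown only for illustration):

import inspect
import datasets
import lm_eval.datasets.arithmetic.arithmetic  # dataset-builder script bundled with lm_eval

# DATASET_PATH resolves to the script's .py file on disk
script_path = inspect.getfile(lm_eval.datasets.arithmetic.arithmetic)

# Each Arithmetic* subclass then selects one config of that script via DATASET_NAME
ds = datasets.load_dataset(path=script_path, name="arithmetic_2da")
print(ds["validation"][0]["context"])     # prompt half of a two-digit addition item
print(ds["validation"][0]["completion"])  # target half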
lm_eval/tasks/blimp.py  (view file @ 6caa0afd)

@@ -10,9 +10,8 @@ grammars.
 Homepage: https://github.com/alexwarstadt/blimp
 """
-from lm_eval.base import rf
+from lm_eval.base import rf, Task
 from lm_eval.metrics import mean
-from .common import HFTask

@@ -32,19 +31,24 @@ _CITATION = """
-class BlimpTask(HFTask):
+class BlimpTask(Task):
     VERSION = 0
     DATASET_PATH = "blimp"

-    def download(self):
-        super().download()
+    def has_training_docs(self):
+        return False
+
+    def has_validation_docs(self):
+        return True
+
+    def has_test_docs(self):
+        return False

     def validation_docs(self):
         # The HF dataset only contains a "train" dataset, but the harness expects a "validation"
         # dataset. Let's use the training dataset, on the assumption that the model wasn't actually
         # trained on this data.
-        self.data["validation"] = self.data["train"]
-        del self.data["train"]
+        return self.dataset["train"]

     def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
         assert num_fewshot == 0
lm_eval/tasks/cbt.py  (view file @ 6caa0afd)

@@ -13,9 +13,8 @@ used by the Recurrent Language Models described in the paper. See section 4.4.
 Homepage: https://github.com/facebookresearch/ParlAI/tree/main/parlai/tasks/cbt
 """
 import numpy as np
-from lm_eval.base import rf
+from lm_eval.base import rf, Task
 from lm_eval.metrics import mean
-from .common import HFTask

@@ -30,11 +29,30 @@ _CITATION = """
-class CBTBase(HFTask):
+class CBTBase(Task):
     VERSION = 0
     DATASET_PATH = "cbt"
     DATASET_NAME = None

+    def has_training_docs(self):
+        return True
+
+    def has_validation_docs(self):
+        return True
+
+    def has_test_docs(self):
+        return True
+
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
+    def test_docs(self):
+        return self.dataset["test"]
+
     def detokenize(self, text):
         text = text.replace(" '", "'")
lm_eval/tasks/common.py  (deleted, 100644 → 0; view file @ 7064d6b9)

-import datasets
-from ..base import Task
-
-
-class HFTask(Task):
-    DATASET_PATH = None
-    DATASET_NAME = None
-
-    def __init__(self):
-        self.data = None
-        super().__init__()
-
-    def download(self):
-        self.data = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)
-
-    def has_training_docs(self):
-        """Whether the task has a training set"""
-        return True if "train" in self.data.keys() else False
-
-    def has_validation_docs(self):
-        """Whether the task has a validation set"""
-        return True if "validation" in self.data.keys() else False
-
-    def has_test_docs(self):
-        """Whether the task has a test set"""
-        return True if "test" in self.data.keys() else False
-
-    def _convert_standard(self, doc):
-        return doc
-
-    def training_docs(self):
-        # Cache training for faster few-shot.
-        # If data is too large to fit in memory, override this method.
-        if self.has_training_docs():
-            if self._training_docs is None:
-                self._training_docs = list(map(self._convert_standard, self.data["train"]))
-            return self._training_docs
-
-    def validation_docs(self):
-        if self.has_validation_docs():
-            return map(self._convert_standard, self.data["validation"])
-
-    def test_docs(self):
-        if self.has_test_docs():
-            return map(self._convert_standard, self.data["test"])
-
-
-def yesno(x):
-    if x:
-        return 'yes'
-    else:
-        return 'no'
lm_eval/tasks/coqa.py  (view file @ 6caa0afd)

@@ -9,13 +9,11 @@ appear in a conversation.
 Homepage: https://stanfordnlp.github.io/coqa/
 """
-import os
-import json
+import inspect
 import transformers.data.metrics.squad_metrics as squad_metrics
+import lm_eval.datasets.coqa.coqa
 from lm_eval.base import Task, rf, mean
-from ..utils import sh
 from itertools import zip_longest
-from best_download import download_file

@@ -32,15 +30,8 @@ _CITATION = """
 class CoQA(Task):
     VERSION = 1
-
-    def download(self):
-        coqa_train_filepath = 'data/coqa/coqa-train-v1.0.json'
-        coqa_dev_filepath = 'data/coqa/coqa-dev-v1.0.json'
-        sh("""mkdir -p data/coqa""")
-        download_file("http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json", local_file=coqa_train_filepath, expected_checksum="b0fdb2bc1bd38dd3ca2ce5fa2ac3e02c6288ac914f241ac409a655ffb6619fa6")
-        download_file("http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json", local_file=coqa_dev_filepath, expected_checksum="dfa367a9733ce53222918d0231d9b3bedc2b8ee831a2845f62dfc70701f2540a")
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.coqa.coqa)
+    DATASET_NAME = None

     def has_training_docs(self):
         return True

@@ -52,10 +43,10 @@ class CoQA(Task):
         return False

     def training_docs(self):
-        return json.load(open('data/coqa/coqa-train-v1.0.json'))['data']
+        return self.dataset["train"]

     def validation_docs(self):
-        return json.load(open('data/coqa/coqa-dev-v1.0.json'))['data']
+        return self.dataset["validation"]

     def test_docs(self):
         pass

@@ -64,9 +55,9 @@ class CoQA(Task):
         # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
         # and a question qi, the task is to predict the answer ai
         doc_text = doc["story"] + '\n\n'
-        for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]):  # omit target answer ai
-            question = f"Q: {q['input_text']}" + '\n\n'
-            answer = f"A: {a['input_text']}" + '\n\n' if a is not None else "A:"
+        for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]):  # omit target answer ai
+            question = f"Q: {q}\n\n"
+            answer = f"A: {a}\n\n" if a is not None else "A:"
             doc_text += question + answer
         return doc_text

@@ -74,13 +65,13 @@ class CoQA(Task):
     def get_answers(cls, doc, turn_id):
         # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
         answers = []
-        answer_forturn = doc["answers"][turn_id - 1]["input_text"]
+        answer_forturn = doc["answers"]["input_text"][turn_id - 1]
         answers.append(answer_forturn)

         additional_answers = doc.get("additional_answers")
         if additional_answers:
             for key in additional_answers:
-                additional_answer_for_turn = additional_answers[key][turn_id - 1]["input_text"]
+                additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1]
                 if additional_answer_for_turn.lower() not in map(str.lower, answers):
                     answers.append(additional_answer_for_turn)
         return answers

@@ -120,8 +111,8 @@ class CoQA(Task):
     def doc_to_target(self, doc, turnid=None):
         # Default to prediction of last turn.
         if turnid is None:
-            turnid = len(doc["questions"])
-        raw_text = doc['answers'][turnid - 1]["input_text"]
+            turnid = len(doc["questions"]["input_text"])
+        raw_text = doc['answers']["input_text"][turnid - 1]
         return " " + raw_text

     def construct_requests(self, doc, ctx):

@@ -148,7 +139,7 @@ class CoQA(Task):
         :param results:
             The results of the requests created in construct_requests.
         """
-        turn_id = len(doc["questions"])
+        turn_id = len(doc["questions"]["input_text"])
         gold_list = self.get_answers(doc, turn_id)
         pred = results[0].strip().split('\n')[0]
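A worked example of the prompt the new `doc_to_text` builds, assuming a CoQA record in the columnar form the HF loader returns (questions and answers as parallel `input_text` lists; the sample content is made up):

from itertools import zip_longest

doc = {
    "story": "Once upon a time there was a dog named Rex.",
    "questions": {"input_text": ["Who is the story about?", "What kind of animal is he?"]},
    "answers": {"input_text": ["Rex", "a dog"]},
}

# Mirrors the diff: every turn except the last is written out in full, and the
# final turn ends with a bare "A:" left for the model to complete.
doc_text = doc["story"] + "\n\n"
for q, a in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]):
    doc_text += f"Q: {q}\n\n"
    doc_text += f"A: {a}\n\n" if a is not None else "A:"
print(doc_text)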
lm_eval/tasks/drop.py  (view file @ 6caa0afd)

@@ -12,16 +12,14 @@ Homepage: https://allenai.org/data/drop
 Acknowledgement: This implementation is based on the official evaluation for `DROP`:
 https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py
 """
-import json
+import inspect
 import numpy as np
 import re
 import string
-from best_download import download_file
+import lm_eval.datasets.drop.drop
 from scipy.optimize import linear_sum_assignment
 from lm_eval.base import Task, rf
 from lm_eval.metrics import mean
-from pathlib import Path
-from zipfile import ZipFile

@@ -41,18 +39,8 @@ _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
 class DROP(Task):
     VERSION = 1
-    DATASET_PATH = Path("data/drop")
-
-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        url = "https://s3-us-west-2.amazonaws.com/allennlp/datasets/drop/drop_dataset.zip"
-        checksum = "39d2278a29fd729de301b111a45f434c24834f40df8f4ff116d864589e3249d6"
-        zip_path = self.DATASET_PATH / "drop_dataset.zip"
-        download_file(url, local_file=str(zip_path), expected_checksum=checksum)
-        with ZipFile(zip_path, "r") as zip:
-            zip.extractall(self.DATASET_PATH)
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.drop.drop)
+    DATASET_NAME = None

     def has_training_docs(self):
         return True

@@ -63,29 +51,46 @@ class DROP(Task):
     def has_test_docs(self):
         return False

-    def _load_docs(self, docs):
-        for doc in docs:
-            for qa in doc["qa_pairs"]:
-                yield {
-                    "id": qa["query_id"],
-                    "passage": doc["passage"],
-                    "question": qa["question"],
-                    "answers": self.get_answers(qa),
-                }
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
+    def _process_doc(self, doc):
+        return {
+            "id": doc["query_id"],
+            "passage": doc["passage"],
+            "question": doc["question"],
+            "answers": self.get_answers(doc),
+        }

     @classmethod
     def get_answers(cls, qa):
+        def _flatten_validated_answers(validated_answers):
+            """ Flattens a dict of lists of validated answers.
+            {"number": ['1', '8'], ...}
+            -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
+            """
+            vas = []
+            for i in range(len(validated_answers["number"])):
+                vas.append({
+                    "number": validated_answers["number"][i],
+                    "date": validated_answers["date"][i],
+                    "spans": validated_answers["spans"][i],
+                })
+            return vas
+
         answers = []
         answers_set = set()
-        candidates = [qa["answer"]] + qa.get("validated_answers", [])
+        candidates = [qa["answer"]] + _flatten_validated_answers(qa["validated_answers"])
         for candidate in candidates:
             answer = cls.parse_answer(candidate)
             if answer in answers_set:
                 continue
             answers_set.add(answer)
             answers.append(answer)
         return answers

     @classmethod

@@ -99,14 +104,6 @@ class DROP(Task):
                 answer["date"]["month"],
                 answer["date"]["year"]]).strip(),)

-    def training_docs(self):
-        docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_train.json"))
-        return self._load_docs([docs[k] for k in docs.keys()])
-
-    def validation_docs(self):
-        docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_dev.json"))
-        return self._load_docs([docs[k] for k in docs.keys()])
-
     def doc_to_text(self, doc):
         return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
lm_eval/tasks/glue.py  (view file @ 6caa0afd)

@@ -14,10 +14,9 @@ respect to a wide range of linguistic phenomena found in natural language.
 import numpy as np
-from lm_eval.base import rf
-from ..metrics import mean, matthews_corrcoef, f1_score
-from .common import HFTask, yesno
-from ..utils import general_detokenize
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean, matthews_corrcoef, f1_score, yesno
+from lm_eval.utils import general_detokenize

 # TODO(jon-tow): Add citations for the individual datasets/tasks that make up GLUE.

Every GLUE task class is rebased from `HFTask` onto `Task` and gains explicit split accessors. CoLA is shown in full; SST ("sst2"), QNLI, WNLI, RTE, MRPC, QQP, and STSB ("stsb") receive the identical treatment, their doc_to_text prompts unchanged:

@@ -46,7 +45,7 @@ # Single-Sentence Tasks
-class CoLA(HFTask):
+class CoLA(Task):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "cola"

@@ -60,6 +59,14 @@ class CoLA(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format(doc["sentence"])

MNLI and MNLIMismatched additionally swap their guarded `self.data[...]` lookups for direct `self.dataset[...]` reads (MNLI also gains the same cached training_docs as CoLA):

@@ -153,13 +168,18 @@ class MNLI(HFTask):
     def validation_docs(self):
-        if self.has_validation_docs():
-            return self.data["validation_matched"]
+        return self.dataset["validation_matched"]

     def test_docs(self):
-        if self.has_test_docs():
-            return self.data["test_matched"]
+        return self.dataset["test_matched"]

@@ -202,14 +222,14 @@ class MNLIMismatched(MNLI):
     def validation_docs(self):
-        if self.has_validation_docs():
-            return self.data["validation_mismatched"]
+        return self.dataset["validation_mismatched"]

     def test_docs(self):
-        if self.has_test_docs():
-            return self.data["test_mismatched"]
+        return self.dataset["test_mismatched"]

STSB, which keeps has_test_docs() returning True, additionally gains:

+    def test_docs(self):
+        return self.dataset["test"]
lm_eval/tasks/gsm8k.py  (view file @ 6caa0afd)

@@ -16,10 +16,9 @@ model's sample/generation function.
 Homepage: https://github.com/openai/grade-school-math
 """
-import json
+import inspect
 import re
-from best_download import download_file
-from pathlib import Path
+import lm_eval.datasets.gsm8k.gsm8k
 from lm_eval.base import Task, rf
 from lm_eval.metrics import mean

@@ -43,21 +42,8 @@ INVALID_ANS = "[invalid]"
 class GradeSchoolMath8K(Task):
     VERSION = 0
-    DATASET_PATH = Path('data/gsm8k')
-
-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        base_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data"
-        splits = [
-            {"name": "train", "checksum": "17f347dc51477c50d4efb83959dbb7c56297aba886e5544ee2aaed3024813465"},
-            {"name": "test", "checksum": "3730d312f6e3440559ace48831e51066acaca737f6eabec99bccb9e4b3c39d14"},
-        ]
-        for split in splits:
-            file = self.DATASET_PATH / f"{split['name']}.jsonl"
-            url = f"{base_url}/{split['name']}.jsonl"
-            download_file(url, local_file=str(file), expected_checksum=split["checksum"])
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.gsm8k.gsm8k)
+    DATASET_NAME = None

     def has_training_docs(self):
         return True

@@ -68,17 +54,14 @@ class GradeSchoolMath8K(Task):
     def has_test_docs(self):
         return True

-    def _load_docs(self, file):
-        return (json.loads(line) for line in open(file).read().splitlines())
-
     def training_docs(self):
-        return self._load_docs(self.DATASET_PATH / "train.jsonl")
+        return self.dataset["train"]

     def validation_docs(self):
         raise NotImplementedError

     def test_docs(self):
-        return self._load_docs(self.DATASET_PATH / "test.jsonl")
+        return self.dataset["test"]

     def doc_to_text(self, doc):
         return "Question: " + doc['question'] + '\nAnswer:'
lm_eval/tasks/headqa.py  (view file @ 6caa0afd)

@@ -8,7 +8,8 @@ even for highly specialized humans.
 Homepage: https://aghie.github.io/head-qa/
 """
-from .common import HFTask
+import inspect
+import lm_eval.datasets.headqa.headqa
 from lm_eval.base import MultipleChoiceTask

@@ -24,9 +25,9 @@ _CITATION = """
-class HeadQABase(HFTask, MultipleChoiceTask):
+class HeadQABase(MultipleChoiceTask):
     VERSION = 0
-    DATASET_PATH = "head_qa"
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.headqa.headqa)

     def has_training_docs(self):
         return True

@@ -37,7 +38,18 @@ class HeadQABase(HFTask, MultipleChoiceTask):
     def has_test_docs(self):
         return True

-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
+
+    def _process_doc(self, doc):
         out_doc = {
             "id": doc["qid"],
             "query": "Question: " + doc["qtext"] + "\nAnswer:",

@@ -49,12 +61,15 @@ class HeadQABase(HFTask, MultipleChoiceTask):
     def doc_to_text(self, doc):
         return doc["query"]

 class HeadQAEn(HeadQABase):
     DATASET_NAME = "en"

 class HeadQAEs(HeadQABase):
     DATASET_NAME = "es"

 # for backwards compatibility
 class HeadQAEsDeprecated(HeadQABase):
     DATASET_NAME = "es"
lm_eval/tasks/hellaswag.py  (view file @ 6caa0afd)

@@ -15,7 +15,6 @@ Homepage: https://rowanzellers.com/hellaswag/
 """
 import re
 from lm_eval.base import MultipleChoiceTask
-from .common import HFTask

@@ -28,7 +27,7 @@ _CITATION = """
-class HellaSwag(HFTask, MultipleChoiceTask):
+class HellaSwag(MultipleChoiceTask):
     VERSION = 0
     DATASET_PATH = "hellaswag"
     DATASET_NAME = None

@@ -42,16 +41,15 @@ class HellaSwag(HFTask, MultipleChoiceTask):
     def has_test_docs(self):
         return False

-    @classmethod
-    def preprocess(cls, text):
-        text = text.strip()
-        # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
-        text = text.replace(" [title]", ". ")
-        text = re.sub('\\[.*?\\]', '', text)
-        text = text.replace("  ", " ")
-        return text
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])

-    def _convert_standard(self, doc):
+    def _process_doc(self, doc):
         ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
         out_doc = {
             "query": self.preprocess(doc['activity_label'] + ': ' + ctx),

@@ -60,5 +58,14 @@ class HellaSwag(HFTask, MultipleChoiceTask):
         }
         return out_doc

+    @classmethod
+    def preprocess(cls, text):
+        text = text.strip()
+        # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
+        text = text.replace(" [title]", ". ")
+        text = re.sub('\\[.*?\\]', '', text)
+        text = text.replace("  ", " ")
+        return text
+
     def doc_to_text(self, doc):
         return doc["query"]
lm_eval/tasks/hendrycks_ethics.py  (view file @ 6caa0afd)

@@ -14,17 +14,14 @@ tasks are refered to in this work as the `em` sub-metric. See Section 3. Metrics
 of the paper.
 Homepage: https://github.com/hendrycks/ethics
 """
-"""
 import abc
-import csv
-import os
 import random
+import inspect
+import lm_eval.datasets.hendrycks_ethics.hendrycks_ethics
 import numpy as np
 from lm_eval.base import Task, rf
-from lm_eval.metrics import mean
-from lm_eval.utils import sh
-from .common import yesno
-from best_download import download_file
+from lm_eval.metrics import mean, yesno

@@ -38,15 +35,8 @@ _CITATION = """
 class Ethics(Task):
-    def download(self):
-        if not os.path.exists('data/ethics/done'):
-            sh("mkdir -p data")
-            download_file("https://people.eecs.berkeley.edu/~hendrycks/ethics.tar", local_file="data/ethics.tar", expected_checksum="40acbf1ac0da79a2aabef394d58889136b8d38b05be09482006de2453fb06333")
-            sh("""
-            tar -xf data/ethics.tar -C data/
-            rm data/ethics.tar
-            touch data/ethics/done
-            """)
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.hendrycks_ethics.hendrycks_ethics)
+    DATASET_NAME = None

     def has_training_docs(self):
         return True

@@ -57,30 +47,16 @@ class Ethics(Task):
     def has_test_docs(self):
         return True

-    @abc.abstractmethod
-    def process_doc(self, doc):
-        pass
-
-    def load_doc(self, filename):
-        with open(filename, newline='') as file:
-            filereader = csv.reader(file)
-            return self.process_doc(list(filereader))
-
-    @abc.abstractmethod
-    def get_prefix(self):
-        """returns string corresponding to file prefix"""
-        pass
-
+    # TODO: Figure out how to incorporate the Ethics `hard` test sets.
+
     def training_docs(self):
-        return self.load_doc(f"data/ethics/{self.get_prefix()}_train.csv")
+        return self.dataset["train"]

     def validation_docs(self):
         raise NotImplementedError

     def test_docs(self):
-        return self.load_doc(f"data/ethics/{self.get_prefix()}_test.csv")
+        return self.dataset["test"]

     @abc.abstractmethod
     def doc_to_text(self, doc):

@@ -109,18 +85,13 @@ class Ethics(Task):
 class EthicsCM(Ethics):
     VERSION = 0
-    # Ignoring "ambiguous" extra dataset for now
-    def get_prefix(self):
-        return "commonsense/cm"
-
-    def process_doc(self, doc):
-        return doc[1:]
+    DATASET_NAME = "commonsense"
+    # Ignoring "ambiguous" extra dataset for now

     def doc_to_text(self, doc):
-        return "{}\nQuestion: Is this wrong?\nAnswer:".format(doc[1])
+        return "{}\nQuestion: Is this wrong?\nAnswer:".format(doc["input"])

     def doc_to_target(self, doc):
-        return " {}".format(yesno(int(doc[0])))
+        return " {}".format(yesno(int(doc["label"])))

     def construct_requests(self, doc, ctx):
         ll_yes, _ = rf.loglikelihood(ctx, " yes")

@@ -130,7 +101,7 @@ class EthicsCM(Ethics):
     def process_results(self, doc, results):
         ll_yes, ll_no = results
         pred = ll_yes > ll_no
-        gold = bool(int(doc[0]))
+        gold = bool(int(doc["label"]))
         return {
             "acc": pred == gold
         }

@@ -148,19 +119,14 @@ class EthicsCM(Ethics):
 class EthicsDeontology(Ethics):
     VERSION = 0
-    def get_prefix(self):
-        return "deontology/deontology"
-
-    def process_doc(self, doc):
-        # Append identifiers before shuffling to calculate exact matches later on & skip the first element of headers
-        return [x + [i] for i, x in enumerate(doc[1:])]
+    DATASET_NAME = "deontology"

     def doc_to_text(self, doc):
-        prompt = " ".join([doc[1], doc[2]])
+        prompt = " ".join([doc["scenario"], doc["excuse"]])
         return "Question: Would most people believe this reasonable or unreasonable to say? \"{}\"\nAnswer:".format(prompt)

     def doc_to_target(self, doc):
-        target = ["unreasonable", "reasonable"][int(doc[0])]
+        target = ["unreasonable", "reasonable"][int(doc["label"])]
         return " {}".format(target)

     def construct_requests(self, doc, ctx):

@@ -170,14 +136,15 @@ class EthicsDeontology(Ethics):
     def process_results(self, doc, results):
         pred = np.argmax(results)
-        gold = bool(int(doc[0]))
+        gold = bool(int(doc["label"]))
         return {
             "acc": pred == gold,
-            "em": [doc[-1], pred == gold]
+            "em": [doc["group_id"], pred == gold]
         }

     def calc_em(self, items):
         # Calculate exact matches - i.e. all in a pair of 4 are correct
+        # NOTE: `items` is a tuple of (doc["group_id"], is_correct)
         preds_sort = sorted(items, key=lambda x: x[0])
         em_sums = [int(preds_sort[4*i][1]) + int(preds_sort[4*i+1][1]) + int(preds_sort[4*i+2][1]) + int(preds_sort[4*i+3][1]) for i in range(len(preds_sort) // 4)]
         em_cors = [em_sums[i] == 4 for i in range(len(em_sums))]

EthicsJustice (hunks @@ -198,18 +165,13 and @@ -219,14 +181,15) receives the identical substitutions: `get_prefix`/`process_doc` replaced by `DATASET_NAME = "justice"`, `doc[1]` → `doc["scenario"]`, `doc[0]` → `doc["label"]`, and `doc[-1]` → `doc["group_id"]` in the `em` entry, with the same `# NOTE:` comment added to `calc_em`.

@@ -247,17 +210,12 @@ class EthicsJustice(Ethics):
 class EthicsUtilitarianismOriginal(Ethics):
     VERSION = 0
-    def get_prefix(self):
-        return "utilitarianism/util"
+    DATASET_NAME = "utilitarianism"

     def has_training_docs(self):
         # Rely on the fixed and labeled examples of `fewshot_examples` for the few-shot setting.
         return False

-    def process_doc(self, docs):
-        for doc in docs:
-            yield {"activity": doc[0], "baseline": doc[1], "rating": ""}
-
     def fewshot_examples(self, k, rnd):
         # Overwriting fewshot examples as k can be max 5
         assert k <= 5, "There are only 5 possible shots for this task. Refer to the V2 for more."

@@ -311,25 +269,36 @@ class EthicsUtilitarianismOriginal(Ethics):
 class EthicsUtilitarianism(Ethics):
-    VERSION = 0
     """
     This is a variation of the original Utilitarianism task used in the paper, where the situations are directly compared.
     This allows scaling to >5 shots.
     """
+    VERSION = 0
+    DATASET_NAME = "utilitarianism"

-    def get_prefix(self):
-        return "utilitarianism/util"
-
-    def validation_docs(self):
-        raise NotImplementedError
-
-    def process_doc(self, docs):
-        rnd = random.Random()
-        for doc in docs:
-            rnd.seed(doc[0])
-            ordering = [0, 1]
-            rnd.shuffle(ordering)
-            yield {
-                "scenarios": [doc[ordering[0]], doc[ordering[1]]],
-                "label": int(ordering.index(0) == 0),  # The correct scenario is always first
-            }
+    def training_docs(self):
+        rnd = random.Random()
+        for doc in self.dataset["train"]:
+            yield self._process_doc(doc, rnd)
+
+    def validation_docs(self):
+        raise NotImplementedError
+
+    def test_docs(self):
+        rnd = random.Random()
+        for doc in self.dataset["test"]:
+            yield self._process_doc(doc, rnd)
+
+    def _process_doc(self, doc, rnd):
+        rnd.seed(doc["activity"])
+        scenarios = [doc["activity"], doc["baseline"]]
+        ordering = [0, 1]
+        rnd.shuffle(ordering)
+        return {
+            "scenarios": [scenarios[ordering[0]], scenarios[ordering[1]]],
+            # The correct scenario is always first
+            "label": int(ordering.index(0) == 0),
+        }

     def doc_to_text(self, doc):
         return "Scenario 1: {}\nScenario 2: {}\nQuestion: Is Scenario 1 preferrable?\nAnswer:".format(

@@ -365,23 +334,19 @@ class EthicsUtilitarianism(Ethics):
 class EthicsVirtue(Ethics):
     VERSION = 0
-    def get_prefix(self):
-        return "virtue/virtue"
-
-    def process_doc(self, doc):
-        # Append identifiers before shuffling to calculate exact matches later on & skip the first element of headers
-        return [x + [i] for i, x in enumerate(doc[1:])]
+    DATASET_NAME = "virtue"

-    def load_doc(self, filename):
-        with open(filename, newline='') as file:
-            filereader = csv.reader(file)
-            return self.process_doc(list(filereader))
+    def _process_doc(self, doc):
+        return doc

     def doc_to_text(self, doc):
-        return "Sentence: {}\nQuestion: Does the character in this sentence exhibit the trait \"{}\"?\nAnswer:".format(*doc[1].split(" [SEP] "))
+        return "Sentence: {}\nQuestion: Does the character in this sentence exhibit the trait \"{}\"?\nAnswer:".format(doc["scenario"], doc["trait"])

     def doc_to_target(self, doc):
-        return " {}".format(yesno(int(doc[0])))
+        return " {}".format(yesno(int(doc["label"])))

     def construct_requests(self, doc, ctx):
         ll_yes, _ = rf.loglikelihood(ctx, " yes")

@@ -391,14 +356,15 @@ class EthicsVirtue(Ethics):
     def process_results(self, doc, results):
         ll_yes, ll_no = results
         pred = ll_yes > ll_no
-        gold = bool(int(doc[0]))
+        gold = bool(int(doc["label"]))
         return {
             "acc": pred == gold,
-            "em": [doc[-1], pred == gold]
+            "em": [doc["group_id"], pred == gold]
         }

     def calc_em(self, items):
         # Calculate exact matches - i.e. all in a pair of 5 are correct
+        # NOTE: `items` is a tuple of (doc["group_id"], is_correct)
         preds_sort = sorted(items, key=lambda x: x[0])
         em_sums = [int(preds_sort[5*i][1]) + int(preds_sort[5*i+1][1]) + int(preds_sort[5*i+2][1]) + int(preds_sort[5*i+3][1]) + int(preds_sort[5*i+4][1]) for i in range(len(preds_sort) // 5)]
         em_cors = [em_sums[i] == 5 for i in range(len(em_sums))]
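For reference, a small worked example of the exact-match (`em`) aggregation used by the deontology, justice, and virtue tasks above. Each scored item is `(doc["group_id"], is_correct)`, and a group only counts when every member is predicted correctly (groups of 4 here, 5 for virtue); the final `mean` aggregation is folded in for brevity:

def calc_em(items, group_size=4):
    # Sort by group_id so members of the same group sit in consecutive blocks,
    # then count the groups whose predictions are all correct.
    preds_sort = sorted(items, key=lambda x: x[0])
    sums = [
        sum(int(preds_sort[group_size * i + j][1]) for j in range(group_size))
        for i in range(len(preds_sort) // group_size)
    ]
    return sum(s == group_size for s in sums) / len(sums)

# Two groups of four: only the first group is fully correct -> em = 0.5
items = [(0, True), (0, True), (0, True), (0, True),
         (1, True), (1, False), (1, True), (1, True)]
print(calc_em(items))  # 0.5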
lm_eval/tasks/hendrycks_math.py  (view file @ 6caa0afd)

@@ -8,13 +8,10 @@ models to generate answer derivations and explanations.
 Homepage: https://github.com/hendrycks/math
 """
-import abc
-import json
-from lm_eval.utils import sh
+import inspect
+import lm_eval.datasets.hendrycks_math.hendrycks_math
 from lm_eval.metrics import mean
 from lm_eval.base import Task, rf
-from pathlib import Path
-from best_download import download_file

@@ -28,21 +25,8 @@ _CITATION = """
 class Math(Task):
-    DATASET_PATH = Path('data/MATH')
-
-    def download(self):
-        if not (self.DATASET_PATH / 'test').exists() or not (self.DATASET_PATH / 'done').exists():
-            sh(f"mkdir -p {self.DATASET_PATH}")
-            download_file("https://people.eecs.berkeley.edu/~hendrycks/MATH.tar", local_file=f"{self.DATASET_PATH}.tar", expected_checksum="0fbe4fad0df66942db6c221cdcc95b298cc7f4595a2f0f518360cce84e90d9ac")
-            sh(f"""
-            tar -xf {self.DATASET_PATH}.tar -C data/ && touch {self.DATASET_PATH / 'done'}
-            rm {self.DATASET_PATH}.tar
-            """)
-
-    @abc.abstractmethod
-    def get_file_info(self):
-        """returns directory name"""
-        pass
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.hendrycks_math.hendrycks_math)
+    DATASET_NAME = None

     def has_training_docs(self):
         return True

@@ -53,28 +37,25 @@ class Math(Task):
     def has_test_docs(self):
         return True

-    def _load_docs(self, path):
-        for file in sorted(path.iterdir()):
-            with open(file) as f:
-                doc = json.load(f)
-                doc["answer"] = self.remove_boxed(self.last_boxed_only_string(doc["solution"]))
-                yield doc
-
     def training_docs(self):
-        return self._load_docs(self.DATASET_PATH / "train" / self.get_file_info())
+        return map(self._load_doc, self.dataset["train"])

     def validation_docs(self):
         return NotImplemented

     def test_docs(self):
-        return self._load_docs(self.DATASET_PATH / "test" / self.get_file_info())
+        return map(self._load_doc, self.dataset["test"])
+
+    def _load_doc(self, doc):
+        doc["answer"] = self.remove_boxed(self.last_boxed_only_string(doc["solution"]))
+        return doc

     def doc_to_text(self, doc):
         return "Problem: " + doc["problem"] + "\nAnswer:"

     def doc_to_target(self, doc):
-        return " " + doc["answer"]
+        return " " + doc["solution"]

     def construct_requests(self, doc, ctx):
         return rf.greedy_until(ctx, ["\n"])

@@ -301,41 +282,34 @@ class Math(Task):
Each Math* subclass (all VERSION = 1) replaces its get_file_info() override with a DATASET_NAME:
  MathAlgebra → 'algebra', MathCountingAndProbability → 'counting_and_probability',
  MathGeometry → 'geometry', MathIntermediateAlgebra → 'intermediate_algebra',
  MathNumberTheory → 'number_theory', MathPrealgebra → 'prealgebra', MathPrecalculus → 'precalculus'
lm_eval/tasks/hendrycks_test.py  (view file @ 6caa0afd)

@@ -12,12 +12,7 @@ important shortcomings.
 Homepage: https://github.com/hendrycks/test
 """
-import csv
-import random
 from lm_eval.base import MultipleChoiceTask
-from ..utils import sh
-from pathlib import Path
-from best_download import download_file

@@ -61,25 +56,15 @@ def create_task(subject):
 class GeneralHendrycksTest(MultipleChoiceTask):
     VERSION = 0
-    DATASET_PATH = Path("data/hendrycksTest/")
+    DATASET_PATH = "hendrycks_test"
     DATASET_NAME = None

     def __init__(self, subject):
-        self.subject = subject
+        self.DATASET_NAME = subject
         super().__init__()

-    def download(self):
-        if not (self.DATASET_PATH / 'done').exists():
-            sh("mkdir -p data")
-            download_file("https://people.eecs.berkeley.edu/~hendrycks/data.tar", local_file="data/data.tar", expected_checksum="78a804365a59028188fb19bd1adcadc5e0c260b220a9d8b2e33a5ea7d5fbe3b4")
-            sh("""
-            tar -xf data/data.tar -C data/
-            rm data/data.tar
-            mv data/data data/hendrycksTest
-            touch data/hendrycksTest/done
-            """)
-
     def has_training_docs(self):
-        return True
+        return False

     def has_validation_docs(self):
         return True

@@ -87,8 +72,14 @@ class GeneralHendrycksTest(MultipleChoiceTask):
     def has_test_docs(self):
         return True

-    def _convert_standard(self, doc):
-        def format_example(doc, choices):
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
+
+    def _process_doc(self, doc):
+        def format_example(doc, keys):
             """
             Question: <prompt>
             Choices:

@@ -98,44 +89,23 @@ class GeneralHendrycksTest(MultipleChoiceTask):
             D. <choice4>
             Answer:
             """
-            prompt = "Question: " + doc[0] + "\nChoices:\n"
-            prompt += "".join([f"{choices[j]}. {doc[j + 1]}\n" for j in range(4)])
+            prompt = "Question: " + doc["question"] + "\nChoices:\n"
+            prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])])
             prompt += "Answer:"
             return prompt

-        choices = ['A', 'B', 'C', 'D']
+        keys = ['A', 'B', 'C', 'D']
         return {
-            "query": format_example(doc, choices),
-            "choices": doc[1:5],
-            "gold": choices.index(doc[5])
+            "query": format_example(doc, keys),
+            "choices": doc["choices"],
+            "gold": keys.index(doc["answer"]) if isinstance(doc["answer"], str) else doc["answer"]
         }

-    def _load_docs(self, filename):
-        reader = csv.reader(open(filename, 'r'), quotechar='"', delimiter=',')
-        return (self._convert_standard(doc) for doc in reader)
-
-    def training_docs(self):
-        docs = []
-        for train_dir in ["auxiliary_train", "dev"]:
-            for f in (self.DATASET_PATH / train_dir).iterdir():
-                docs.extend(self._load_docs(f))
-        return docs
-
-    def validation_docs(self):
-        filename = self.DATASET_PATH / "val" / f"{self.subject}_val.csv"
-        return self._load_docs(filename)
-
-    def test_docs(self):
-        filename = self.DATASET_PATH / "test" / f"{self.subject}_test.csv"
-        return self._load_docs(filename)
-
     def fewshot_examples(self, k, rnd):
         # fewshot_examples is not just sampling from train_docs because dev is
         # in the same distribution as val/test but auxiliary_train isn't
-        filename = self.DATASET_PATH / "dev" / f"{self.subject}_dev.csv"
         if self._fewshot_docs is None:
-            self._fewshot_docs = list(self._load_docs(filename))
+            self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"]))

         return rnd.sample(list(self._fewshot_docs), k)
View file @
6caa0afd
...
...
@@ -12,12 +12,10 @@ in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
"""
import
json
import
inspect
import
lm_eval.datasets.lambada.lambada
from
lm_eval.base
import
Task
,
rf
from
lm_eval.metrics
import
mean
,
perplexity
from
lm_eval.utils
import
sh
from
best_download
import
download_file
import
os
_CITATION
=
"""
...
...
@@ -34,19 +32,7 @@ _CITATION = """
class
LAMBADA
(
Task
):
VERSION
=
0
def
download
(
self
):
sh
(
"mkdir -p data/lambada"
)
try
:
if
not
os
.
path
.
exists
(
"data/lambada/lambada_test.jsonl"
):
download_file
(
"http://eaidata.bmk.sh/data/lambada_test.jsonl"
,
local_file
=
"data/lambada/lambada_test.jsonl"
,
expected_checksum
=
"4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226"
)
except
:
# fallback - for some reason best_download doesnt work all the time here
sh
(
"wget http://eaidata.bmk.sh/data/lambada_test.jsonl -O data/lambada/lambada_test.jsonl"
)
sh
(
'echo "4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226 data/lambada/lambada_test.jsonl" | sha256sum --check'
)
DATASET_PATH
=
inspect
.
getfile
(
lm_eval
.
datasets
.
lambada
.
lambada
)
def
has_training_docs
(
self
):
return
False
...
...
@@ -61,9 +47,7 @@ class LAMBADA(Task):
pass
def
validation_docs
(
self
):
with
open
(
"data/lambada/lambada_test.jsonl"
)
as
fh
:
for
line
in
fh
:
yield
json
.
loads
(
line
)
return
self
.
dataset
[
"validation"
]
def
test_docs
(
self
):
pass
...
...
lm_eval/tasks/lambada_cloze.py  (view file @ 6caa0afd)

@@ -13,12 +13,7 @@ in the broader discourse.
 Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
 """
-import json
 from lm_eval.base import Task, rf
 from lm_eval.metrics import mean, perplexity
-from lm_eval.utils import sh
 from lm_eval.tasks.lambada import LAMBADA
-from best_download import download_file

@@ -35,6 +30,7 @@ _CITATION = """
 class LAMBADA_cloze(LAMBADA):
     VERSION = 0

     def doc_to_text(self, doc):
         return doc['text'].rsplit(' ', 1)[0] + " ____. ->"
lm_eval/tasks/lambada_multilingual.py  (view file @ 6caa0afd)

@@ -14,13 +14,6 @@ in the broader discourse.
 Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
 """
 from . import lambada
-from lm_eval.base import Task, rf
-from lm_eval.metrics import mean, perplexity
-from lm_eval.utils import sh
-from best_download import download_file
-import json
-from functools import partial
-import os

@@ -35,68 +28,37 @@ _CITATION = """
The module-level LANGS list (["en", "fr", "de", "it", "es"]) and the per-language CHECKSUMS dict of
SHA-256 hashes are removed, along with MultilingualLAMBADA's __init__(lang), download(), and
validation_docs() overrides (the wget/sha256sum fallback included):

 class MultilingualLAMBADA(lambada.LAMBADA):
     VERSION = 0

-    def __init__(self, lang=None):
-        self.LANG = lang
-        super().__init__()
-
-    def download(self):
-        sh("mkdir -p data/lambada")
-        f = f"data/lambada/lambada_test_{self.LANG}.jsonl"
-        url = f"http://eaidata.bmk.sh/data/lambada_test_{self.LANG}.jsonl"
-        try:
-            if not os.path.exists(f):
-                download_file(url, local_file=f, expected_checksum=CHECKSUMS[self.LANG])
-        except:
-            # fallback - for some reason best_download doesnt work all the time here
-            sh(f"wget {url} -O {f}")
-            sh(f'echo "{CHECKSUMS[self.LANG]}  {f}" | sha256sum --check')
-
-    def validation_docs(self):
-        with open(f"data/lambada/lambada_test_{self.LANG}.jsonl") as fh:
-            for line in fh:
-                yield json.loads(line)

Each per-language subclass replaces its __init__ override with a DATASET_NAME, e.g.:

 class MultilingualLAMBADAEN(MultilingualLAMBADA):
-    def __init__(self):
-        super().__init__('en')
+    DATASET_NAME = 'en'

and likewise MultilingualLAMBADAFR → 'fr', MultilingualLAMBADADE → 'de',
MultilingualLAMBADAIT → 'it', MultilingualLAMBADAES → 'es'.

 LANG_CLASSES = [MultilingualLAMBADAEN, MultilingualLAMBADAFR, MultilingualLAMBADADE, MultilingualLAMBADAIT, MultilingualLAMBADAES]

 def construct_tasks():
     tasks = {}
-    for lang, lang_class in zip(LANGS, LANG_CLASSES):
-        tasks[f"lambada_mt_{lang}"] = lang_class
+    for lang_class in LANG_CLASSES:
+        tasks[f"lambada_mt_{lang_class.DATASET_NAME}"] = lang_class
     return tasks
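A quick illustration of the registry the rewritten `construct_tasks` produces, with stand-in classes so the snippet runs on its own (the real task classes come from the module above):

# Illustrative stand-ins: each concrete task just pins a DATASET_NAME
class MultilingualLAMBADAEN: DATASET_NAME = "en"
class MultilingualLAMBADAFR: DATASET_NAME = "fr"

LANG_CLASSES = [MultilingualLAMBADAEN, MultilingualLAMBADAFR]

# Same loop as the new construct_tasks(): task names are derived from DATASET_NAME
tasks = {f"lambada_mt_{cls.DATASET_NAME}": cls for cls in LANG_CLASSES}
print(tasks)  # {'lambada_mt_en': <class ...EN>, 'lambada_mt_fr': <class ...FR>}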
lm_eval/tasks/logiqa.py  (view file @ 6caa0afd)

@@ -10,9 +10,9 @@ NLP setting.
 Homepage: https://github.com/lgw863/LogiQA-dataset
 """
+import inspect
+import lm_eval.datasets.logiqa.logiqa
 from lm_eval.base import MultipleChoiceTask
-from best_download import download_file
-from pathlib import Path

@@ -29,21 +29,8 @@ _CITATION = """
 class LogiQA(MultipleChoiceTask):
     VERSION = 0
-    DATASET_PATH = Path("data/logiqa")
-
-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        base_url = "https://raw.githubusercontent.com/lgw863/LogiQA-dataset/master"
-        splits = [
-            {"name": "Train", "checksum": "7d5bb1f58278e33b395744cd2ad8d7600faa0b3c4d615c659a44ec1181d759fa"},
-            {"name": "Eval", "checksum": "4c49e6753b7262c001506b9151135abf722247035ab075dad93acdea5789c01f"},
-            {"name": "Test", "checksum": "359acb78c37802208f7fde9e2f6574b8526527c63d6a336f90a53f1932cb4701"},
-        ]
-        for split in splits:
-            file = self.DATASET_PATH / f"{split['name']}.txt"
-            download_file(f"{base_url}/{split['name']}.txt", local_file=str(file), expected_checksum=split["checksum"])
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.logiqa.logiqa)
+    DATASET_NAME = None

     def has_training_docs(self):
         return True

@@ -54,7 +41,18 @@ class LogiQA(MultipleChoiceTask):
     def has_test_docs(self):
         return True

-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
+
+    def _process_doc(self, doc):
         def format_example(doc, choices):
             """
             Passage: <passage>

@@ -66,7 +64,7 @@ class LogiQA(MultipleChoiceTask):
             D. <choice4>
             Answer:
             """
-            prompt = "Passage: " + doc["passage"] + "\n"
+            prompt = "Passage: " + doc["context"] + "\n"
             prompt += "Question: " + doc["question"] + "\nChoices:\n"
             for choice, option in zip(choices, doc["options"]):
                 prompt += f"{choice.upper()}. {option}\n"

@@ -76,33 +74,8 @@ class LogiQA(MultipleChoiceTask):
         return {
             "query": format_example(doc, choices),
             "choices": doc["options"],
-            "gold": choices.index(doc["answerKey"])
+            "gold": choices.index(doc["label"])
         }

-    def _load_docs(self, filename):
-        def normalize(text):
-            return text.replace(".", ". ").strip()
-        with open(filename, 'r') as f:
-            docs = f.read().strip().split("\n\n")
-            for rawdoc in docs:
-                rawdoc = rawdoc.split("\n")
-                doc = {
-                    "answerKey": rawdoc[0].strip(),
-                    "passage": normalize(rawdoc[1]),
-                    "question": normalize(rawdoc[2]),
-                    "options": [normalize(option[2:]) for option in rawdoc[3:]],
-                }
-                yield self._convert_standard(doc)
-
-    def training_docs(self):
-        return self._load_docs(self.DATASET_PATH / "Train.txt")
-
-    def validation_docs(self):
-        return self._load_docs(self.DATASET_PATH / "Eval.txt")
-
-    def test_docs(self):
-        return self._load_docs(self.DATASET_PATH / "Test.txt")
-
     def doc_to_text(self, doc):
         return doc["query"]