gaoqiong / lm-evaluation-harness · Commit 6caa0afd (unverified)

Authored Apr 01, 2022 by Leo Gao; committed by GitHub on Apr 01, 2022
Parents: 7064d6b9 9434722c

Merge pull request #300 from jon-tow/hf-dataset-refactor

Refactor `Task` downloading to use `HuggingFace.datasets`
Showing 7 changed files with 97 additions and 155 deletions:

    lm_eval/tasks/truthfulqa.py   +28  -52
    lm_eval/tasks/unscramble.py   +10  -39
    lm_eval/tasks/webqs.py        +11  -4
    lm_eval/tasks/wikitext.py     +16  -29
    lm_eval/tasks/winogrande.py   +11  -4
    lm_eval/tasks/wsc273.py       +19  -25
    setup.py                      +2   -2
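The change repeated across all seven files: each task's hand-rolled `download()` method (hard-coded URLs, SHA256 checksums, gzip/zip extraction) is deleted, and data loading is delegated to the `datasets` library through `DATASET_PATH`/`DATASET_NAME` class attributes. A minimal sketch of the pattern, assuming the base `Task` class now feeds these attributes to `datasets.load_dataset` (the base-class change itself is not part of this diff):

    import datasets

    class TaskSketch:
        # A hub dataset name or, for tasks with custom scripts, a local path.
        DATASET_PATH = "web_questions"
        DATASET_NAME = None  # optional config name within the dataset

        def download(self):
            # One call replaces per-task URL lists, checksums, and unzip
            # logic: `datasets` downloads, verifies, and caches the data.
            self.dataset = datasets.load_dataset(self.DATASET_PATH, self.DATASET_NAME)

    task = TaskSketch()
    task.download()
    print(task.dataset["train"][0]["question"])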
lm_eval/tasks/truthfulqa.py

@@ -19,16 +19,14 @@ we could try this?
 Homepage: https://github.com/sylinrl/TruthfulQA
 """
-import csv
-import json
+import inspect
 import numpy as np
 import sacrebleu
+import datasets
+import lm_eval.datasets.truthfulqa.truthfulqa
 from rouge_score import rouge_scorer, scoring
 from lm_eval.base import rf, Task
-from pathlib import Path
-from best_download import download_file
-from ..metrics import mean
-from datasets import load_metric
+from lm_eval.metrics import mean

 _CITATION = """
@@ -62,15 +60,8 @@ QA_PROMPT = (
 class TruthfulQAMultipleChoice(Task):
     VERSION = 1
-    DATASET_PATH = Path('data/truthfulqa/mc')
-
-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        mc_url = "https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/data/mc_task.json"
-        checksum = "6eb4125d25750c0145c4be2dce00440736684ab6f74ce6bff2139571cc758954"
-        download_file(mc_url, local_file=str(self.DATASET_PATH / "mc_task.json"), expected_checksum=checksum)
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.truthfulqa.truthfulqa)
+    DATASET_NAME = "multiple_choice"

     def has_training_docs(self):
         return False
@@ -85,8 +76,7 @@ class TruthfulQAMultipleChoice(Task):
         raise NotImplementedError()

     def validation_docs(self):
-        with open(self.DATASET_PATH / "mc_task.json") as f:
-            return json.load(f)
+        return self.dataset["validation"]

     def test_docs(self):
         raise NotImplementedError()
@@ -121,7 +111,7 @@ class TruthfulQAMultipleChoice(Task):
             return [rf.loglikelihood(ctx, " " + t)[0] for t in targets]

         # MC1 and MC2 targets are not always the same set of strings so we collect
         # likelihoods separately for simpler processing.
-        return get_lls(doc['mc1_targets']) + get_lls(doc['mc2_targets'])
+        return get_lls(doc['mc1_targets']["choices"]) + get_lls(doc['mc2_targets']["choices"])

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
@@ -139,14 +129,14 @@ class TruthfulQAMultipleChoice(Task):
         def mc2(lls):
             # Split on the first `0` as everything before it is true (`1`).
-            split_idx = list(doc['mc2_targets'].values()).index(0)
+            split_idx = list(doc['mc2_targets']["labels"]).index(0)
             # Compute the normalized probability mass for the correct answer.
             ll_true, ll_false = lls[:split_idx], lls[split_idx:]
             p_true, p_false = np.exp(np.array(ll_true)), np.exp(np.array(ll_false))
             p_true = p_true / (sum(p_true) + sum(p_false))
             return sum(p_true)

-        split_idx = len(doc['mc1_targets'])
+        split_idx = len(doc['mc1_targets']["choices"])
         mc1_lls, mc2_lls = results[:split_idx], results[split_idx:]

         return {
             "mc1": mc1(mc1_lls),
@@ -168,19 +158,12 @@ class TruthfulQAMultipleChoice(Task):
 class TruthfulQAGeneration(Task):
     VERSION = 1
-    DATASET_PATH = Path('data/truthfulqa/generation')
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.truthfulqa.truthfulqa)
+    DATASET_NAME = "generation"

     def __init__(self):
         super().__init__()
-        self.bleurt = load_metric("bleurt", cache_dir="lm_cache")
-
-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        url = "https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/TruthfulQA.csv"
-        checksum = "8d7dd15f033196140f032d97d30f037da7a7b1192c3f36f9937c1850925335a2"
-        download_file(url, local_file=str(self.DATASET_PATH / "TruthfulQA.csv"), expected_checksum=checksum)
+        self.bleurt = datasets.load_metric("bleurt")

     def has_training_docs(self):
         return False
@@ -194,36 +177,29 @@ class TruthfulQAGeneration(Task):
     def training_docs(self):
         raise NotImplementedError()

-    def _split_multi_answer(self, answers, sep=';'):
-        answers = answers.strip().split(sep)
-        split_answers = []
+    def _format_answers(self, answers):
+        formatted_answers = []
         for answer in answers:
             answer = answer.strip()
             if len(answer):
                 # Add a period after all answers.
                 if answer[-1] != '.':
-                    split_answers.append(answer + '.')
+                    formatted_answers.append(answer + '.')
                 else:
-                    split_answers.append(answer)
-        return split_answers
+                    formatted_answers.append(answer)
+        return formatted_answers

     def validation_docs(self):
-        with open(self.DATASET_PATH / "TruthfulQA.csv", newline='') as csvfile:
-            doc_reader = csv.DictReader(csvfile)
-            for doc in doc_reader:
-                # Ensure that references exist.
-                if not doc['Correct Answers'] or not doc['Incorrect Answers']:
-                    continue
-                correct_answers = self._split_multi_answer(doc['Correct Answers'])
-                if "I have no comment." not in correct_answers:
-                    correct_answers.append("I have no comment.")
-                incorrect_answers = self._split_multi_answer(doc['Incorrect Answers'])
-                doc = {
-                    'question': doc['Question'].strip(),
-                    'correct_answers': correct_answers,
-                    'incorrect_answers': incorrect_answers
-                }
-                yield doc
+        for doc in self.dataset["validation"]:
+            incorrect_answers = self._format_answers(doc['incorrect_answers'])
+            correct_answers = self._format_answers(doc['correct_answers'])
+            if "I have no comment." not in correct_answers:
+                correct_answers.append("I have no comment.")
+            yield {
+                'question': doc['question'].strip(),
+                'correct_answers': correct_answers,
+                'incorrect_answers': incorrect_answers
+            }

     def test_docs(self):
         raise NotImplementedError()
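Note the schema shift driving the `mc1_targets`/`mc2_targets` edits: the old JSON stored targets as a dict mapping answer string to a 0/1 label, while the HF dataset exposes parallel `choices`/`labels` lists. A toy run of the updated `mc2` arithmetic on that layout (values invented for illustration):

    import numpy as np

    doc = {"mc2_targets": {"choices": ["yes", "maybe", "no"], "labels": [1, 1, 0]}}
    lls = [-1.0, -2.0, -0.5]  # fake log-likelihoods, one per choice

    # All `1` labels precede the first `0`, so index(0) splits true from false.
    split_idx = list(doc["mc2_targets"]["labels"]).index(0)
    ll_true, ll_false = lls[:split_idx], lls[split_idx:]
    p_true, p_false = np.exp(np.array(ll_true)), np.exp(np.array(ll_false))
    p_true = p_true / (sum(p_true) + sum(p_false))
    print(sum(p_true))  # normalized probability mass on the true answers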
lm_eval/tasks/unscramble.py

@@ -8,11 +8,8 @@ addition, or deletion of characters, and asking it to recover the original word.
 Homepage: https://github.com/openai/gpt-3/tree/master/data
 """
-import gzip
-import json
-import shutil
-from pathlib import Path
-from best_download import download_file
+import inspect
+import lm_eval.datasets.unscramble.unscramble
 from lm_eval.base import Task, rf
 from lm_eval.metrics import mean
@@ -32,30 +29,10 @@ _CITATION = """
 """

-def extract_gzip(gz, to):
-    with gzip.open(gz, 'rb') as fin:
-        with open(to, 'wb') as fout:
-            shutil.copyfileobj(fin, fout)
-
-
 class WordUnscrambleTask(Task):
     VERSION = 0
-    BASE_PATH = Path("data/unscramble")
-    FILENAME = None
-    CHECKSUM = None  # SHA256 Checksum.
-
-    def __init__(self):
-        super().__init__()
-
-    def download(self):
-        if not self.BASE_PATH.exists():
-            Path.mkdir(self.BASE_PATH, parents=True)
-        file = self.BASE_PATH / self.FILENAME
-        if not file.exists():
-            rawfile = file.parent / (file.name + ".gz")
-            base_url = "https://raw.githubusercontent.com/openai/gpt-3/master/data"
-            download_file(f"{base_url}/{self.FILENAME}.gz", local_file=str(rawfile), expected_checksum=self.CHECKSUM)
-            extract_gzip(gz=rawfile, to=file)
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.unscramble.unscramble)
+    DATASET_NAME = None

     def has_training_docs(self):
         return False
@@ -67,8 +44,7 @@ class WordUnscrambleTask(Task):
         return False

     def validation_docs(self):
-        file = self.BASE_PATH / self.FILENAME
-        return (json.loads(line) for line in open(file).read().splitlines())
+        return self.dataset["validation"]

     def doc_to_text(self, doc):
         return doc["context"]
@@ -99,25 +75,20 @@ class WordUnscrambleTask(Task):
 class Anagrams1(WordUnscrambleTask):
-    FILENAME = "mid_word_1_anagrams.jsonl"
-    CHECKSUM = "6768a86896083199de4815d4964cb2f6f1046476cfd80c2a562784f182905979"
+    DATASET_NAME = "mid_word_1_anagrams"


 class Anagrams2(WordUnscrambleTask):
-    FILENAME = "mid_word_2_anagrams.jsonl"
-    CHECKSUM = "c3d839d09a7954b78a27cd2cd75d4ed0488656c56ef4dbd741a005343826cb01"
+    DATASET_NAME = "mid_word_2_anagrams"


 class CycleLetters(WordUnscrambleTask):
-    FILENAME = "cycle_letters_in_word.jsonl"
-    CHECKSUM = "1689c9002bb8c5988bf5f05e977c9db92f57932c1b5a38998c29ac0dd71e1d42"
+    DATASET_NAME = "cycle_letters_in_word"


 class RandomInsertion(WordUnscrambleTask):
-    FILENAME = "random_insertion_in_word.jsonl"
-    CHECKSUM = "72e65d83da53d15752ee0c47379509de149ddbad32d61184e5991df29616b78a"
+    DATASET_NAME = "random_insertion_in_word"


 class ReversedWords(WordUnscrambleTask):
-    FILENAME = "reversed_words.jsonl"
-    CHECKSUM = "133a08f875cd6c1ef8608a3233571a773881cc27b1c707de738cc6543439332a"
+    DATASET_NAME = "reversed_words"
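Worth noting how the new `DATASET_PATH` works here: `inspect.getfile` on a module returns the path to its source file, so `datasets.load_dataset` receives a local dataset-script path rather than a hub name, and each subclass now selects a config via `DATASET_NAME` instead of a filename plus checksum. Illustrated with a stdlib module, since the harness's `lm_eval.datasets.unscramble.unscramble` script is outside this diff:

    import inspect
    import json  # stand-in; the task points at lm_eval.datasets.unscramble.unscramble

    print(inspect.getfile(json))  # e.g. /usr/lib/python3.10/json/__init__.py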
lm_eval/tasks/webqs.py

@@ -9,9 +9,8 @@ The questions are popular ones asked on the web (at least in 2013).
 Homepage: https://worksheets.codalab.org/worksheets/0xba659fe363cb46e7a505c5b6a774dc8a
 """
-from .common import HFTask
-from lm_eval.base import rf
-from ..metrics import mean
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean

 _CITATION = """
@@ -32,7 +31,7 @@ _CITATION = """
 """

-class WebQs(HFTask):
+class WebQs(Task):
     VERSION = 0
     DATASET_PATH = "web_questions"
     DATASET_NAME = None
@@ -46,6 +45,14 @@ class WebQs(HFTask):
     def has_test_docs(self):
         return True

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def test_docs(self):
+        return self.dataset["test"]
+
     def doc_to_text(self, doc):
         return "Question: " + doc['question'] + '\nAnswer:'
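The added `training_docs` caches its split where `test_docs` does not, presumably because few-shot sampling draws from training docs repeatedly. A self-contained sketch of the caching behavior (the `FakeSplit` stand-in is invented for illustration):

    class FakeSplit:
        def __iter__(self):
            print("reading split...")
            yield from [{"question": "q1"}, {"question": "q2"}]

    class TaskSketch:
        _training_docs = None
        dataset = {"train": FakeSplit()}

        def training_docs(self):
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    t = TaskSketch()
    t.training_docs()
    t.training_docs()  # "reading split..." prints only once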
lm_eval/tasks/wikitext.py

@@ -9,11 +9,10 @@ NOTE: This `Task` is based on WikiText-2.
 Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
 """
-import os
 import re
-from lm_eval.base import rf, PerplexityTask
-from lm_eval.utils import sh
-from best_download import download_file
+import inspect
+import lm_eval.datasets.wikitext.wikitext
+from lm_eval.base import PerplexityTask

 _CITATION = """
@@ -64,45 +63,33 @@ def wikitext_detokenizer(string):
 class WikiText(PerplexityTask):
     VERSION = 1
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.wikitext.wikitext)
+    DATASET_NAME = "wikitext-2-raw-v1"

-    def download(self):
-        if not os.path.exists('data/wikitext/wikitext-2-raw/wiki.valid.raw'):
-            os.makedirs("data/wikitext/", exist_ok=True)
-            download_file("https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip", local_file="data/wikitext/wikitext-2-raw-v1.zip", expected_checksum="ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11")
-            sh("cd data/wikitext/ && unzip wikitext-2-raw-v1.zip")
-
-    def has_validation_docs(self):
-        return True
-
-    def has_train_docs(self):
+    def has_training_docs(self):
+        return True
+
+    def has_validation_docs(self):
         return True

     def has_test_docs(self):
         return True

-    def docs_for_split(self, split):
-        ret = []
-        for line in open(f"data/wikitext/wikitext-2-raw/wiki.{split}.raw").read().split('\n'):
-            rline = line.replace("= = =", "===").replace("= =", "==").strip()
-            if rline.startswith('= ') and rline.strip().endswith(' ='):
-                s = '\n'.join(ret)
-                if s.strip():
-                    yield s
-                ret = []
-            ret.append(line)
-        yield '\n'.join(ret)
-
-    def validation_docs(self):
-        return self.docs_for_split('valid')
+    def training_docs(self):
+        return map(self._load_doc, self.dataset["train"])

-    def train_docs(self):
-        return self.docs_for_split('train')
+    def validation_docs(self):
+        return map(self._load_doc, self.dataset["validation"])

     def test_docs(self):
-        return self.docs_for_split('test')
+        return map(self._load_doc, self.dataset["test"])
+
+    def _load_doc(self, doc):
+        return doc["page"]

     def doc_to_target(self, doc):
         return wikitext_detokenizer(doc)

     def count_words(self, doc):
         # count number of words in *original doc before detokenization*
         return len(re.split(r"\s+", doc))
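The deleted `docs_for_split` stitched raw lines back into articles by watching for `= Title =` headers; the replacement assumes the local `wikitext` dataset script performs that grouping and exposes each article under a `page` key (the script itself is not in this diff). A toy stand-in for the new path:

    fake_split = [{"page": "= First Article =\nSome article text."},
                  {"page": "= Second Article =\nMore text."}]

    def _load_doc(doc):
        return doc["page"]

    for page in map(_load_doc, fake_split):
        print(page.splitlines()[0])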
lm_eval/tasks/winogrande.py

@@ -15,9 +15,8 @@ See: https://arxiv.org/abs/1806.02847
 Homepage: https://leaderboard.allenai.org/winogrande/submissions/public
 """
 import numpy as np
-from .common import HFTask
-from lm_eval.base import rf
-from ..metrics import mean
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean

 _CITATION = """
@@ -30,7 +29,7 @@ _CITATION = """
 """

-class Winogrande(HFTask):
+class Winogrande(Task):
     VERSION = 0
     DATASET_PATH = "winogrande"
     DATASET_NAME = "winogrande_xl"
@@ -46,6 +45,14 @@ class Winogrande(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return self.partial_context(doc, doc["option" + doc["answer"]])
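The unchanged `doc_to_text` at the end of the hunk leans on the HF `winogrande` schema, where `answer` is the string "1" or "2" and selects between the `option1`/`option2` fields. A toy doc showing the key lookup:

    doc = {"sentence": "The trophy didn't fit in the case because _ was too big.",
           "option1": "the trophy", "option2": "the case", "answer": "1"}
    print(doc["option" + doc["answer"]])  # -> "the trophy"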
lm_eval/tasks/wsc273.py

@@ -14,10 +14,8 @@ See: https://arxiv.org/abs/1806.02847
 Homepage: https://cs.nyu.edu/~davise/papers/WinogradSchemas/WS.html
 """
 import numpy as np
-import random
-from lm_eval.base import rf
-from ..metrics import mean
-from .common import HFTask
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean

 _CITATION = """
@@ -37,7 +35,7 @@ _CITATION = """
 """

-class WinogradSchemaChallenge273(HFTask):
+class WinogradSchemaChallenge273(Task):
     VERSION = 0
     DATASET_PATH = "winograd_wsc"
     DATASET_NAME = "wsc273"
@@ -45,19 +43,24 @@ class WinogradSchemaChallenge273(HFTask):
     upper_pronouns = ["A", "An", "The", "She", "He",
                       "It", "They", "My", "His", "Her", "Their"]

-    def __init__(self):
-        super().__init__()
-        self.data = self.__clean_data()
+    def has_training_docs(self):
+        return False
+
+    def has_validation_docs(self):
+        return False
+
+    def has_test_docs(self):
+        return True
+
+    def test_docs(self):
+        return map(self._load_doc, self.dataset["test"])

-    def __clean_data(self):
+    def _load_doc(self, doc):
         # The HF implementation of `wsc273` is not `partial evaluation` friendly.
-        data = []
-        for doc in self.data["test"]:
-            doc["text"] = doc["text"].replace("  ", " ")
-            doc["options"][0] = self.__normalize_option(doc, doc["options"][0])
-            doc["options"][1] = self.__normalize_option(doc, doc["options"][1])
-            data.append(doc)
-        return {"test": data}
+        doc["text"] = doc["text"].replace("  ", " ")
+        doc["options"][0] = self.__normalize_option(doc, doc["options"][0])
+        doc["options"][1] = self.__normalize_option(doc, doc["options"][1])
+        return doc

     def __normalize_option(self, doc, option):
         # Append `'s` to possessive determiner based options.
@@ -70,15 +73,6 @@ class WinogradSchemaChallenge273(HFTask):
             return option.replace(pronoun, pronoun.lower())
         return option

-    def has_training_docs(self):
-        return False
-
-    def has_validation_docs(self):
-        return False
-
-    def has_test_docs(self):
-        return True
-
     def fewshot_examples(self, k, rnd):
         # NOTE: `super().fewshot_examples` samples from training docs which are
         # not available for this test-set-only dataset.
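The eager `__clean_data` (which rewrote the whole test split in `__init__`) becomes a per-doc `_load_doc` applied lazily through `map` in `test_docs`. A toy contrast showing that normalization now runs only when docs are consumed:

    docs = [{"text": "The  city councilmen refused..."}]  # note the double space

    def _load_doc(doc):
        doc["text"] = doc["text"].replace("  ", " ")
        return doc

    lazy = map(_load_doc, docs)  # nothing normalized yet
    print(next(lazy)["text"])    # normalization happens on consumption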
setup.py

@@ -21,8 +21,7 @@ setuptools.setup(
     python_requires='>=3.6',
     install_requires=[
         "black",
-        "best_download==0.0.9",
-        "datasets==1.15.1",
+        "datasets==2.0.0",
         "click>=7.1",
         "scikit-learn>=0.24.1",
         "torch>=1.7",
@@ -43,6 +42,7 @@ setuptools.setup(
         "openai==0.6.4",
         "jieba==0.42.1",
         "nagisa==0.2.7",
+        "bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt"
     ],
     dependency_links=[
         "https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
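Moving `bleurt` into `install_requires` as a direct reference matters because pip dropped `dependency_links` processing in pip 19.0; the PEP 508 `name@URL` form is what modern pip honors. A sketch of the resulting requirement entry (same URL as in the diff):

    install_requires = [
        "datasets==2.0.0",
        "bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
    ]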