gaoqiong / lm-evaluation-harness / Commits / 11f614b0

Unverified commit 11f614b0, authored Apr 30, 2022 by Stella Biderman, committed by GitHub on Apr 30, 2022.

    Merge branch 'master' into task_doc

Parents: 0a6a9b7e, e00d682f
Changes: 129 files in total; this page shows 20 changed files with 874 additions and 284 deletions (+874, -284).
lm_eval/datasets/wikitext/wikitext.py        +227   -0
lm_eval/decontamination/__init__.py            +0   -0
lm_eval/decontamination/archiver.py          +147   -0
lm_eval/decontamination/decontaminate.py     +153   -0
lm_eval/decontamination/janitor.py             +0   -0
lm_eval/evaluator.py                          +39   -6
lm_eval/metrics.py                             +8   -1
lm_eval/models/gpt2.py                         +6   -1
lm_eval/tasks/__init__.py                      +8   -1
lm_eval/tasks/anli.py                         +15   -7
lm_eval/tasks/arc.py                          +19   -3
lm_eval/tasks/arithmetic.py                   +32   -60
lm_eval/tasks/asdiv.py                        +15   -43
lm_eval/tasks/blimp.py                        +18   -8
lm_eval/tasks/cbt.py                          +28   -3
lm_eval/tasks/common.py                        +0   -52
lm_eval/tasks/coqa.py                         +20   -23
lm_eval/tasks/drop.py                         +39   -36
lm_eval/tasks/glue.py                         +94   -17
lm_eval/tasks/gsm8k.py                         +6   -23
lm_eval/datasets/wikitext/wikitext.py — new file (0 → 100644)

# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# NOTE: This is a modified version of https://github.com/huggingface/datasets/blob/master/datasets/wikitext/wikitext.py
# that returns Wiki pages instead of Wiki text line-by-line.
"""WikiText Dataset."""


import os

import datasets


_CITATION = """\
@misc{merity2016pointer,
    title={Pointer Sentinel Mixture Models},
    author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},
    year={2016},
    eprint={1609.07843},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""

_DESCRIPTION = """\
The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified
Good and Featured articles on Wikipedia. The dataset is available under the Creative Commons Attribution-ShareAlike
License.
"""

_HOMEPAGE = "https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/"

_LICENSE = "Creative Commons Attribution-ShareAlike 4.0 International (CC BY-SA 4.0)"

_DATA_URL = "https://s3.amazonaws.com/research.metamind.io/wikitext"


class WikitextConfig(datasets.BuilderConfig):
    """BuilderConfig for GLUE."""

    def __init__(self, data_url, **kwargs):
        """BuilderConfig for Wikitext

        Args:
            data_url: `string`, url to the dataset (word or raw level)
            **kwargs: keyword arguments forwarded to super.
        """
        super(WikitextConfig, self).__init__(
            version=datasets.Version("1.0.0",),
            **kwargs,
        )
        self.data_url = data_url


class Wikitext(datasets.GeneratorBasedBuilder):
    """TODO(wikitext_103): Short description of my dataset."""

    # TODO(wikitext_103): Set up version.
    VERSION = datasets.Version("0.1.0")
    BUILDER_CONFIGS = [
        WikitextConfig(
            name="wikitext-103-v1",
            data_url=_DATA_URL + "/" + "wikitext-103-v1.zip",
            description="Word level dataset. No processing is needed other than replacing newlines with <eos> tokens.",
        ),
        WikitextConfig(
            name="wikitext-2-v1",
            data_url=_DATA_URL + "/" + "wikitext-2-v1.zip",
            description="Word level dataset. No processing is needed other than replacing newlines with <eos> tokens.",
        ),
        WikitextConfig(
            name="wikitext-103-raw-v1",
            data_url=_DATA_URL + "/" + "wikitext-103-raw-v1.zip",
            description="Raw level dataset: the raw tokens before the addition of <unk> tokens. "
            "They should only be used for character level work or for creating newly derived datasets.",
        ),
        WikitextConfig(
            name="wikitext-2-raw-v1",
            data_url=_DATA_URL + "/" + "wikitext-2-raw-v1.zip",
            description="Raw level dataset: the raw tokens before the addition of <unk> tokens. "
            "They should only be used for character level work or for creating newly derived datasets.",
        ),
    ]

    def _info(self):
        # TODO(wikitext): Specifies the datasets.DatasetInfo object
        return datasets.DatasetInfo(
            # This is the description that will appear on the datasets page.
            description=_DESCRIPTION,
            # datasets.features.FeatureConnectors
            features=datasets.Features(
                {
                    "page": datasets.Value("string")
                    # These are the features of your dataset like images, labels ...
                }
            ),
            # If there's a common (input, target) tuple from the features,
            # specify them here. They'll be used if as_supervised=True in
            # builder.as_dataset.
            supervised_keys=None,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators."""
        # TODO(wikitext): Downloads the data and defines the splits
        # dl_manager is a datasets.download.DownloadManager that can be used to
        # download and extract URLs
        if self.config.name == "wikitext-103-v1":
            data_file = dl_manager.download_and_extract(self.config.data_url)
            data_dir = os.path.join(data_file, "wikitext-103")
            return [
                datasets.SplitGenerator(
                    name=datasets.Split.TEST,
                    gen_kwargs={"data_file": os.path.join(data_dir, "wiki.test.tokens"), "split": "test"},
                ),
                datasets.SplitGenerator(
                    name=datasets.Split.TRAIN,
                    gen_kwargs={"data_file": os.path.join(data_dir, "wiki.train.tokens"), "split": "train"},
                ),
                datasets.SplitGenerator(
                    name=datasets.Split.VALIDATION,
                    gen_kwargs={"data_file": os.path.join(data_dir, "wiki.valid.tokens"), "split": "valid"},
                ),
            ]
        else:
            if self.config.name == "wikitext-103-raw-v1":
                data_file = dl_manager.download_and_extract(self.config.data_url)
                data_dir = os.path.join(data_file, "wikitext-103-raw")
                return [
                    datasets.SplitGenerator(
                        name=datasets.Split.TEST,
                        gen_kwargs={"data_file": os.path.join(data_dir, "wiki.test.raw"), "split": "test"},
                    ),
                    datasets.SplitGenerator(
                        name=datasets.Split.TRAIN,
                        gen_kwargs={"data_file": os.path.join(data_dir, "wiki.train.raw"), "split": "train"},
                    ),
                    datasets.SplitGenerator(
                        name=datasets.Split.VALIDATION,
                        gen_kwargs={"data_file": os.path.join(data_dir, "wiki.valid.raw"), "split": "valid"},
                    ),
                ]
            else:
                if self.config.name == "wikitext-2-raw-v1":
                    data_file = dl_manager.download_and_extract(self.config.data_url)
                    data_dir = os.path.join(data_file, "wikitext-2-raw")
                    return [
                        datasets.SplitGenerator(
                            name=datasets.Split.TEST,
                            gen_kwargs={"data_file": os.path.join(data_dir, "wiki.test.raw"), "split": "test"},
                        ),
                        datasets.SplitGenerator(
                            name=datasets.Split.TRAIN,
                            gen_kwargs={"data_file": os.path.join(data_dir, "wiki.train.raw"), "split": "train"},
                        ),
                        datasets.SplitGenerator(
                            name=datasets.Split.VALIDATION,
                            gen_kwargs={"data_file": os.path.join(data_dir, "wiki.valid.raw"), "split": "valid"},
                        ),
                    ]
                else:
                    if self.config.name == "wikitext-2-v1":
                        data_file = dl_manager.download_and_extract(self.config.data_url)
                        data_dir = os.path.join(data_file, "wikitext-2")
                        return [
                            datasets.SplitGenerator(
                                name=datasets.Split.TEST,
                                gen_kwargs={"data_file": os.path.join(data_dir, "wiki.test.tokens"), "split": "test"},
                            ),
                            datasets.SplitGenerator(
                                name=datasets.Split.TRAIN,
                                gen_kwargs={"data_file": os.path.join(data_dir, "wiki.train.tokens"), "split": "train"},
                            ),
                            datasets.SplitGenerator(
                                name=datasets.Split.VALIDATION,
                                gen_kwargs={"data_file": os.path.join(data_dir, "wiki.valid.tokens"), "split": "valid"},
                            ),
                        ]

    def _generate_examples(self, data_file, split):
        """Yields examples."""
        with open(data_file, encoding="utf-8") as f:
            key = 0
            ret = []
            data = f.read().split("\n")
            for line in data:
                rline = line.replace("= = =", "===").replace("= =", "==").strip()
                if rline.startswith('= ') and rline.strip().endswith(' ='):
                    page = '\n'.join(ret)
                    if page.strip():
                        yield key, {"page": page}
                        key += 1
                    ret = []
                ret.append(line)
            page = '\n'.join(ret)
            yield key, {"page": page}
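In practice, the page-level script above is consumed through the standard datasets.load_dataset entry point. A minimal usage sketch follows; the local script path and config name are illustrative assumptions, not part of this commit:

import datasets

# Hypothetical local path to the dataset script added above.
script_path = "lm_eval/datasets/wikitext/wikitext.py"

# Each example is a whole Wiki page under the "page" key,
# rather than one line of WikiText per example.
wikitext = datasets.load_dataset(script_path, "wikitext-2-raw-v1")
print(wikitext["validation"][0]["page"][:200])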
lm_eval/decontamination/__init__.py — new file (0 → 100644), empty.
scripts/clean_training_data/archiver.py → lm_eval/decontamination/archiver.py — moved and modified

@@ -4,6 +4,9 @@ import json
 import jsonlines
 import io
 import datetime
+import mmap
+import tqdm
+from pathlib import Path


 def json_serial(obj):
     """JSON serializer for objects not serializable by default json code"""

@@ -59,16 +62,19 @@ class Reader:
         else:
             yield text


+# Simple text reader and writer with same interface as above
 class TextArchive:
-    def __init__(self, file_path, mode="ab"):
+    def __init__(self, file_path, mode="rb+"):
         self.file_path = file_path
         dir_name = os.path.dirname(file_path)
         if dir_name:
             os.makedirs(dir_name, exist_ok=True)
-        self.fh = open(self.file_path, mode)

+        if not os.path.exists(file_path):
+            Path(file_path).touch()
+
+        self.fh = open(self.file_path, mode)

-    def add_data(self, data, meta={}):
+    def add_data(self, data):
         self.fh.write(data.encode('UTF-8') + b'\n')

     def commit(self):

@@ -79,12 +85,63 @@ class TextReader:
     def __init__(self, file_path):
         self.file_path = file_path

+    # Optimized mmap read with infrequent tqdm updates to maintain speed
+    # Tested up to 250MB/s.
+    def read_tqdm(self, update_frequency=10000):
+        current_file_position = 0
+        line_counter = 0
+        with open(self.file_path, 'r') as fh, \
+            tqdm.tqdm(total=os.path.getsize(self.file_path), dynamic_ncols=True,
+                      unit="byte", unit_scale=1) as progress:
+            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
+                for line in iter(mmap_obj.readline, b""):
+                    line = line.decode("utf-8")
+                    line_counter += 1
+                    if line_counter == update_frequency:
+                        new_file_pos = mmap_obj.tell()
+                        bytes_read = new_file_pos - current_file_position
+                        current_file_position = new_file_pos
+                        progress.update(bytes_read)
+                        line_counter = 0
+                    yield line[:-1]
+
+    def read_and_tell(self):
+        current_file_position = 0
+        with open(self.file_path, 'r', encoding="utf8") as fh:
+            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
+                for line in iter(mmap_obj.readline, b""):
+                    line = line.decode("utf-8")
+                    new_file_pos = mmap_obj.tell()
+                    raw_bytes_read = new_file_pos - current_file_position
+                    current_file_position = new_file_pos
+                    yield line[:-1], raw_bytes_read
+
+    def read(self):
+        with open(self.file_path, 'r', encoding="utf8") as fh:
+            self.fh = fh
+            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
+                for line in iter(mmap_obj.readline, b""):
+                    line = line.decode("utf-8")
+                    yield line[:-1]
+
+    def read_slow(self):
+        with open(self.file_path, 'r', encoding="utf8") as fh:
+            while True:
-                line = self.fh.readline()
+                line = fh.readline()
                 if line == -1 or line == "":
                     break
-                else: yield line[:-1]
\ No newline at end of file
+                else:
+                    yield line[:-1]
+
+
+# Optimized for speed. Decompresses the archive in shell before
+# using the mmap'd TextReader.
+class ZStdTextReader:
+    def __init__(self, file):
+        self.file = file
+
+    def read_tqdm(self):
+        decompressed_file = self.file[:-4]
+        print("Decompressing file, please wait...")
+        os.system(f"zstd -d {self.file}")  # linux decompress is faster
+        reader = TextReader(decompressed_file)
+        yield from reader.read_tqdm()
+        os.remove(decompressed_file)
\ No newline at end of file
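For orientation, a minimal sketch of how the moved TextArchive/TextReader pair can be used; the file name and its contents are hypothetical, not taken from this diff:

from lm_eval.decontamination.archiver import TextArchive, TextReader

path = "example_lines.txt"  # hypothetical file for illustration

# Append two lines, then stream them back with the mmap-based reader.
archive = TextArchive(path, mode="ab")
archive.add_data("the quick brown fox")
archive.add_data("jumps over the lazy dog")
archive.commit()

for line in TextReader(path).read():
    print(line)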
lm_eval/decontamination/decontaminate.py — new file (0 → 100644)

import time
import random
import pickle
import json
import glob
import os
import collections

from .janitor import Janitor, word_ngrams
from .archiver import ZStdTextReader


# Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
    simulated_overlap = 0.1
    contaminated = int(len(docs) * simulated_overlap)
    return random.sample(range(len(docs)), contaminated)


# Returns a dictionary containing all overlapping documents in each
# task. In the standard use case, an overlap occurs when any of the 13-grams
# found in the task document exist in the training set documents.
#
# To generate 13-grams for the pile see scripts/clean_training_data. The final output of these
# scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst"
# files. These should exist in the "ngrams_path" provided to this function.

# Algorithm:
# 1. Build lookups for each dataset {ngram: list(document_ids)}
# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
# 3. Full scan the 13-grams from the training set against the merged lookup,
#    saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)}
# 4. Strip the task_set from the dictionary keys and return
#
# We cache the task+set lookups as well as the overlaps.
def get_train_overlap(docs_by_task_set, ngrams_path, limit):
    # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)

    info_dict_path = os.path.join(ngrams_path, "info.json")
    info_dict = json.load(open(info_dict_path, "r"))
    ngrams_n_size = info_dict["ngram_size"]

    janitor = Janitor()

    # Build lookup for each dataset first in case we use different task combinations later
    print("Building Lookups...")
    start = time.perf_counter()

    def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit):
        return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"

    lookups = {}
    duplicates = {}  # (task_name, task_set): set(doc_ids)}
    sets_to_decontaminate = len(docs_by_task_set.keys())

    for (task_name, task_set), docs in docs_by_task_set.items():
        if not os.path.exists(f"data/{task_name}"):
            os.mkdir(f"data/{task_name}")

        # Check if we've decontaminated this combination before
        overlaps_dump_path = get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit)
        if os.path.exists(overlaps_dump_path):
            duplicates[(task_name, task_set)] = pickle.load(open(overlaps_dump_path, "rb"))
            sets_to_decontaminate -= 1
            continue
        else:
            duplicates[(task_name, task_set)] = set()

        # Build/load the task lookup {ngram: set(documents)}.
        task_set_lookup_path = f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
        if os.path.exists(task_set_lookup_path):
            print(f"{task_set_lookup_path} available, loading...")
            lookups[(task_name, task_set)] = pickle.load(open(task_set_lookup_path, "rb"))
        else:
            print(f"{task_set_lookup_path} not available, building...")
            lookup = collections.defaultdict(set)

            for doc_id, document in enumerate(docs):
                ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size)
                for ngram in ngrams:
                    lookup[ngram].add(doc_id)

            pickle.dump(lookup, open(task_set_lookup_path, "wb"))
            lookups[(task_name, task_set)] = lookup

    elapsed = time.perf_counter() - start
    print(f"Building lookups took {elapsed:0.5f} seconds.")

    matched_ngrams = []

    if sets_to_decontaminate > 0:
        print("Merging lookups...")
        start = time.perf_counter()
        merged_lookup = collections.defaultdict(list)
        for (task_name, task_set), lookup in lookups.items():
            for ngram, doc_ids in lookup.items():
                merged_lookup[ngram].append((task_name, task_set, doc_ids))

        elapsed = time.perf_counter() - start
        print(f"Merging lookups took {elapsed:0.5f} seconds.")

        print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
        files = glob.glob(os.path.join(ngrams_path, f"*.sorted.zst"))
        print(files)

        for file in files:
            start = time.perf_counter()
            print(f"Scanning {file}")
            reader = ZStdTextReader(file)
            total_ngrams = 0
            unique_ngrams = 0
            matching_unique = 0
            non_matching_unique = 0

            current_ngram = ""
            for line in reader.read_tqdm():  # Scan training set ngrams file
                total_ngrams += 1
                [ngram, document_id] = line.rsplit(" ", 1)
                if ngram != current_ngram:  # Only need to match the ngram once in training set
                    unique_ngrams += 1
                    current_ngram = ngram
                    if ngram in merged_lookup:
                        matched_ngrams.append(ngram)  # For logging
                        matching_unique += 1
                        for task_name, task_set, doc_ids in merged_lookup[ngram]:
                            task_doc_set = duplicates[(task_name, task_set)]
                            for doc_id in doc_ids:  # Record contamination across all relevant task/set combos
                                task_doc_set.add(doc_id)
                        del merged_lookup[ngram]  # No point matching again
                    else:
                        non_matching_unique += 1

            print(f"Total Ngrams: {total_ngrams}")
            print(f"Unique Ngrams: {unique_ngrams}")
            print(f"Unique Matching: {matching_unique}")
            print(f"Unique Non Matching: {non_matching_unique}")
            print("Matched ngrams:")
            for ngram in matched_ngrams:
                print(ngram)

            elapsed = time.perf_counter() - start
            print(f"Read took {elapsed:0.5f} seconds.")
            print(f"Speed: {(os.path.getsize(file) / 1000000.0) / elapsed}MB/second")

        print(duplicates)

        # Dump overlaps separately
        for (task_name, task_set), doc_ids in duplicates.items():
            overlaps_dump_path = get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit)
            pickle.dump(doc_ids, open(overlaps_dump_path, "wb"))

    # Strip task set and return
    return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
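A minimal sketch of how get_train_overlap is invoked, assuming the 13-gram files were produced by scripts/clean_training_data; the task name, documents, and directory below are placeholders, not from this commit:

from lm_eval.decontamination.decontaminate import get_train_overlap

# Documents grouped by (task_name, task_set); contents are illustrative only.
docs_by_task_set = {
    ("lambada", "test"): ["first evaluation document ...", "second evaluation document ..."],
}
ngrams_path = "path/to/pile_13grams"  # hypothetical directory with info.json and *.sorted.zst files
limit = None

overlaps = get_train_overlap(docs_by_task_set, ngrams_path, limit)
# e.g. {"lambada": {0, 17, ...}} - indices of contaminated documents per task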
scripts/clean_training_data/janitor.py → lm_eval/decontamination/janitor.py — file moved, no content changes.
lm_eval/evaluator.py — modified

@@ -6,15 +6,18 @@ import lm_eval.metrics
 import lm_eval.models
 import lm_eval.tasks
 import lm_eval.base
+import lm_eval.decontamination
 import numpy as np

 from lm_eval.utils import positional_deprecated, run_task_tests
+from lm_eval.decontamination.decontaminate import get_train_overlap


 @positional_deprecated
 def simple_evaluate(model, model_args=None, tasks=[],
                     num_fewshot=0, batch_size=None, device=None,
                     no_cache=False, limit=None, bootstrap_iters=100000,
-                    description_dict=None, check_integrity=False):
+                    description_dict=None, check_integrity=False, decontamination_ngrams_path=None):

     """Instantiate and evaluate a model on a list of tasks.

     :param model: Union[str, LM]

@@ -72,7 +75,8 @@ def simple_evaluate(model, model_args=None, tasks=[],
         task_dict=task_dict,
         num_fewshot=num_fewshot,
         limit=limit,
-        description_dict=description_dict
+        description_dict=description_dict,
+        decontamination_ngrams_path=decontamination_ngrams_path,
     )

     # add info about the model and few shot config

@@ -90,9 +94,11 @@ def simple_evaluate(model, model_args=None, tasks=[],
     return results


+decontaminate_suffix = "_decontaminate"
+

 @positional_deprecated
-def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None):
+def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None, decontamination_ngrams_path=None):
     """Instantiate and evaluate a model on a list of tasks.

     :param lm: obj

@@ -120,6 +126,8 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         # nudge people to not specify it at all
         print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")

+    decontaminate = decontamination_ngrams_path is not None
+
     task_dict_items = [
         (name, task)
         for name, task in task_dict.items()

@@ -132,6 +140,8 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
     requests = collections.defaultdict(list)
     requests_origin = collections.defaultdict(list)

+    overlaps = collections.defaultdict(list)  # {task_name: contaminated_docs}
+
     # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
     # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
     # over-engineering is bad (or we could make it write the requests to disk and then read them back out again

@@ -140,6 +150,8 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
     # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
     docs = {}

+    docs_for_decontamination = collections.defaultdict(list)
+
     # get lists of each type of request
     for task_name, task in task_dict_items:
         versions[task_name] = task.VERSION

@@ -147,7 +159,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
         if task.has_test_docs():
             task_doc_func = task.test_docs
+            task_set = "test"  # Required for caching in the decontamination
         elif task.has_validation_docs():
+            task_set = "val"  # Required for caching in the decontamination
             task_doc_func = task.validation_docs
         else:
             raise RuntimeError("Task has neither test_docs nor validation_docs")

@@ -161,6 +175,10 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         description = description_dict[task_name] if description_dict and task_name in description_dict else ""

         for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
+
+            if decontaminate and task.should_decontaminate():
+                docs_for_decontamination[(task_name, task_set)].append(task.doc_to_decontamination_query(doc))
+
             docs[(task_name, doc_id)] = doc
             ctx = task.fewshot_context(
                 doc=doc,

@@ -177,6 +195,11 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
                 # doc_id: unique id that we can get back to a doc using `docs`
                 requests_origin[req.request_type].append((i, task_name, doc, doc_id))

+    # Compare all tasks/sets at once to ensure a single training set scan
+    if decontaminate:
+        print("Finding train/test overlap, please wait...")
+        overlaps = get_train_overlap(docs_for_decontamination, decontamination_ngrams_path, limit)
+
     # all responses for each (task, doc)
     process_res_queue = collections.defaultdict(list)

@@ -207,18 +230,28 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         metrics = task.process_results(doc, requests)
         for metric, value in metrics.items():
             vals[(task_name, metric)].append(value)

+            # Re-use the evaluation for the decontaminated set by just ignoring the overlaps
+            if decontaminate and task_name in overlaps:
+                if doc_id not in overlaps[task_name]:
+                    vals[(task_name, metric + decontaminate_suffix)].append(value)
+
     # aggregate results
     for (task_name, metric), items in vals.items():
         task = task_dict[task_name]
-        results[task_name][metric] = task.aggregation()[metric](items)
+        real_metric = metric  # key when looking up the metric with task.aggregation
+        if metric.endswith(decontaminate_suffix):
+            real_metric = metric.replace(decontaminate_suffix, "")  # decontaminated still uses the same metric
+        results[task_name][metric] = task.aggregation()[real_metric](items)

         # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
         # so we run them less iterations. still looking for a cleaner way to do this
         stderr = lm_eval.metrics.stderr_for_metric(
-            metric=task.aggregation()[metric],
+            metric=task.aggregation()[real_metric],
             bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters,
         )

         if stderr is not None:
             results[task_name][metric + "_stderr"] = stderr(items)
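With the plumbing above, decontaminated scores are requested by passing the n-grams directory through simple_evaluate. A minimal sketch; the model, task, and path values are placeholders, not from this commit:

import lm_eval.evaluator

results = lm_eval.evaluator.simple_evaluate(
    model="gpt2",
    model_args="pretrained=gpt2",
    tasks=["lambada"],
    num_fewshot=0,
    decontamination_ngrams_path="path/to/pile_13grams",  # hypothetical path
)
# For tasks that implement should_decontaminate(), each metric is duplicated
# under a "<metric>_decontaminate" key with overlapping documents excluded.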
lm_eval/metrics.py — modified

@@ -3,7 +3,7 @@ from collections.abc import Iterable

 import numpy as np
 import sacrebleu
-import sklearn
+import sklearn.metrics
 import random

@@ -245,3 +245,10 @@ def stderr_for_metric(metric, bootstrap_iters):
     }

     return stderr.get(metric, None)
+
+
+def yesno(x):
+    if x:
+        return 'yes'
+    else:
+        return 'no'
lm_eval/models/gpt2.py — modified

@@ -12,9 +12,14 @@ class HFLM(BaseLM):
         assert isinstance(pretrained, str)
         assert isinstance(batch_size, int)

         if device:
             if device not in ["cuda", "cpu"]:
                 device = int(device)
             self._device = torch.device(device)
+            print(f"Using device '{device}'")
         else:
+            print("Device not specificed")
+            print(f"Cuda Available? {torch.cuda.is_available()}")
             self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

         # TODO: update this to be less of a hack once subfolder is fixed in HF
lm_eval/tasks/__init__.py — modified

@@ -15,6 +15,7 @@ from . import wsc273
 from . import winogrande
 from . import quac
 from . import hellaswag
+from . import swag
 from . import openbookqa
 from . import squad
 from . import naturalqs

@@ -50,6 +51,7 @@ from . import truthfulqa
 from . import blimp
 from . import asdiv
 from . import gsm8k
+from . import storycloze

 ########################################
 # Translation tasks

@@ -135,8 +137,8 @@ TASK_REGISTRY = {
     # "quac": quac.QuAC, # not implemented yet
     "logiqa": logiqa.LogiQA,
     "hellaswag": hellaswag.HellaSwag,
+    "swag": swag.SWAG,
     "openbookqa": openbookqa.OpenBookQA,
-    # "sat": sat.SATAnalogies, # not implemented yet
     "squad2": squad.SQuAD2,
     "race": race.RACE,
     # "naturalqs": naturalqs.NaturalQs, # not implemented yet

@@ -297,6 +299,11 @@ TASK_REGISTRY = {
     "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
     "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
     "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
+
+    # Requires manual download of data.
+    # "storycloze_2016": storycloze.StoryCloze2016,
+    # "storycloze_2018": storycloze.StoryCloze2018,
+    # "sat": sat.SATAnalogies,
 }
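Once registered, the new task is reachable by name through the registry lookup. A minimal sketch, assuming the harness's existing get_task_dict helper:

from lm_eval import tasks

# "swag" is the entry added to TASK_REGISTRY above.
task_dict = tasks.get_task_dict(["swag"])
print(task_dict["swag"].VERSION)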
lm_eval/tasks/anli.py — modified

@@ -10,9 +10,8 @@ provided explanations.
 Homepage: "https://github.com/facebookresearch/anli"
 """
 import numpy as np
-from lm_eval.base import rf
-from ..metrics import mean
-from .common import HFTask
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean


 _CITATION = """

@@ -31,7 +30,7 @@ _CITATION = """
 """


-class ANLIBase(HFTask):
+class ANLIBase(Task):
     VERSION = 0
     DATASET_PATH = "anli"
     DATASET_NAME = None

@@ -49,16 +48,16 @@ class ANLIBase(HFTask):
     def training_docs(self):
         if self.has_training_docs():
             if self._training_docs is None:
-                self._training_docs = list(self.data["train_r" + str(self.SPLIT)])
+                self._training_docs = list(self.dataset["train_r" + str(self.SPLIT)])
             return self._training_docs

     def validation_docs(self):
         if self.has_validation_docs():
-            return self.data["dev_r" + str(self.SPLIT)]
+            return self.dataset["dev_r" + str(self.SPLIT)]

     def test_docs(self):
         if self.has_test_docs():
-            return self.data["test_r" + str(self.SPLIT)]
+            return self.dataset["test_r" + str(self.SPLIT)]

     def doc_to_text(self, doc):
         # OA does this a bit weirdly: they prepend "anli 1: anli 1: " to the beginning

@@ -67,6 +66,12 @@ class ANLIBase(HFTask):
         # want to do it exactly as OA did?
         return doc['premise'] + '\nQuestion: ' + doc['hypothesis'] + ' True, False, or Neither?\nAnswer:'

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["premise"]
+
     def doc_to_target(self, doc):
         # True = entailment
         # False = contradiction

@@ -125,11 +130,14 @@ class ANLIBase(HFTask):
             "acc": True
         }

+
 class ANLIRound1(ANLIBase):
     SPLIT = 1

+
 class ANLIRound2(ANLIBase):
     SPLIT = 2

+
 class ANLIRound3(ANLIBase):
     SPLIT = 3
lm_eval/tasks/arc.py — modified

@@ -13,7 +13,6 @@ a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questi
 Homepage: https://allenai.org/data/arc
 """
 from lm_eval.base import MultipleChoiceTask
-from .common import HFTask


 _CITATION = """

@@ -27,7 +26,7 @@ _CITATION = """
 """


-class ARCEasy(HFTask, MultipleChoiceTask):
+class ARCEasy(MultipleChoiceTask):
     VERSION = 0
     DATASET_PATH = "ai2_arc"
     DATASET_NAME = "ARC-Easy"

@@ -41,7 +40,18 @@ class ARCEasy(HFTask, MultipleChoiceTask):
     def has_test_docs(self):
         return True

-    def _convert_standard(self, doc):
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
+    def test_docs(self):
+        return map(self._process_doc, self.dataset["test"])
+
+    def _process_doc(self, doc):
         # NOTE: Some `doc["answerKey"]`s are in numeric string format being one
         # of {'1', '2', '3', '4', '5'}. We map them back to letters.
         num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}

@@ -57,6 +67,12 @@ class ARCEasy(HFTask, MultipleChoiceTask):
     def doc_to_text(self, doc):
         return doc["query"]

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["query"]
+

 class ARCChallenge(ARCEasy):
     DATASET_PATH = "ai2_arc"
lm_eval/tasks/arithmetic.py — modified

@@ -7,13 +7,10 @@ problem in natural language.

 Homepage: https://github.com/openai/gpt-3/tree/master/data
 """
-import abc
-import json
-import os
-from collections import namedtuple
+import inspect
+import lm_eval.datasets.arithmetic.arithmetic
 from lm_eval.base import Task, rf
 from lm_eval.metrics import mean
-from best_download import download_file


 _CITATION = """

@@ -31,33 +28,9 @@ _CITATION = """
 """


-ArithmeticDoc = namedtuple('ArithmeticDoc', ['context', 'completion'])
-
-
 class Arithmetic(Task):
     VERSION = 0
-    directory = 'data/arithmetic/'
-
-    def __init__(self):
-        super().__init__()
-
-    def download(self):
-        file_name, checksum = self.get_file_download_info()
-        url = 'https://raw.githubusercontent.com/openai/gpt-3/master/data/' + file_name
-        if not os.path.exists(self.directory):
-            os.makedirs(self.directory)
-        download_file(url, local_file=self.directory + file_name, expected_checksum=checksum)
-        self.set_docs()
-
-    @abc.abstractmethod
-    def get_file_download_info(self):
-        """returns a tuple of (file_name, checksum)"""
-        pass
-
-    def set_docs(self):
-        file_name, _ = self.get_file_download_info()
-        jsons = open(self.directory + file_name, 'r')
-        self._docs = [self.load_doc(json.loads(line)) for line in jsons]
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.arithmetic.arithmetic)

     def has_training_docs(self):
         return False

@@ -72,25 +45,25 @@ class Arithmetic(Task):
         return NotImplemented

     def validation_docs(self):
-        return self._docs
+        return self.dataset["validation"]

     def test_docs(self):
         return NotImplemented

     def doc_to_text(self, doc):
-        return doc.context
+        return doc["context"]
+
+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["context"]

     def doc_to_target(self, doc):
-        return doc.completion
-
-    def load_doc(self, doc_json):
-        return ArithmeticDoc(context=doc_json['context'].strip()
-            .replace('\n\n', '\n')
-            .replace('Q:', 'Question:')
-            .replace('A:', 'Answer:'), completion=doc_json['completion'])
+        return doc["completion"]

     def construct_requests(self, doc, ctx):
-        ll, is_prediction = rf.loglikelihood(ctx, doc.completion)
+        ll, is_prediction = rf.loglikelihood(ctx, doc["completion"])
         return is_prediction

     def process_results(self, doc, results):

@@ -111,41 +84,40 @@ class Arithmetic(Task):

 class Arithmetic2DPlus(Arithmetic):
-    def get_file_download_info(self):
-        return 'two_digit_addition.jsonl', '75a54b7a3db3b23369df74fe440c23025f3d3c51f664300bd3d56632b2617b3d'
+    DATASET_NAME = "arithmetic_2da"


 class Arithmetic2DMinus(Arithmetic):
-    def get_file_download_info(self):
-        return 'two_digit_subtraction.jsonl', 'da956066ff108c00b341d360567472784f5fd872d6465071b44a14291205bc03'
+    DATASET_NAME = "arithmetic_2ds"


 class Arithmetic3DPlus(Arithmetic):
-    def get_file_download_info(self):
-        return 'three_digit_addition.jsonl', '124865e30efd2abfbc1855dd34c218fc02d32d780ace970ab9b4ea3fa74c798b'
+    DATASET_NAME = "arithmetic_3da"


 class Arithmetic3DMinus(Arithmetic):
-    def get_file_download_info(self):
-        return 'three_digit_subtraction.jsonl', '7fc6aaedcb0e2bd17c398dd4147c5585b1e608278a8e98b914e69656707d6a29'
+    DATASET_NAME = "arithmetic_3ds"


 class Arithmetic4DPlus(Arithmetic):
-    def get_file_download_info(self):
-        return 'four_digit_addition.jsonl', '459c6f75baa2e8d7cf50bdd07db6d0ca9133a6b137d95d09267db85b6e07f391'
+    DATASET_NAME = "arithmetic_4da"


 class Arithmetic4DMinus(Arithmetic):
-    def get_file_download_info(self):
-        return 'four_digit_subtraction.jsonl', '0c47db40a10c052ef0cf732a9ef2edaa53d66377d43eb47a9c382d33a8af7102'
+    DATASET_NAME = "arithmetic_4ds"


 class Arithmetic5DPlus(Arithmetic):
-    def get_file_download_info(self):
-        return 'five_digit_addition.jsonl', '30ada42efe315b958c6e9649274005d3b720e50298e92c3a2d321f8996e58f54'
+    DATASET_NAME = "arithmetic_5da"


 class Arithmetic5DMinus(Arithmetic):
-    def get_file_download_info(self):
-        return 'five_digit_subtraction.jsonl', '8b98ccfc943cbf9193bcf1984954aa0b1a4527016072d972a2b055cc1482ca3c'
+    DATASET_NAME = "arithmetic_5ds"


 class Arithmetic2DMultiplication(Arithmetic):
-    def get_file_download_info(self):
-        return 'two_digit_multiplication.jsonl', '5613d1d1cc3b2c03edc1990252247d34c10ec82944b2cdeb19e71b00f237f431'
+    DATASET_NAME = "arithmetic_2dm"


 class Arithmetic1DComposite(Arithmetic):
-    def get_file_download_info(self):
-        return 'single_digit_three_ops.jsonl', '08b34e3272a8ff1d4932d63f251519d14c485c38d582366e1e323d0b859c3925'
+    DATASET_NAME = "arithmetic_1dc"
lm_eval/tasks/asdiv.py — modified

@@ -14,15 +14,10 @@ NOTE: We currently ignore formulas for answer generation.

 Homepage: https://github.com/chaochun/nlu-asdiv-dataset
 """
-from lm_eval.base import Task
-from pathlib import Path
-from best_download import download_file
-import xml.etree.ElementTree as ET
-from lm_eval.base import rf
-from lm_eval.metrics import mean, perplexity
-import numpy as np
-from zipfile import ZipFile
-import os
+import inspect
+import lm_eval.datasets.asdiv.asdiv
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean


 _CITATION = """

@@ -39,39 +34,11 @@ _CITATION = """

 class Asdiv(Task):
     VERSION = 0
-    DATASET_PATH = Path("data/asdiv")
-
-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        url = "https://github.com/chaochun/nlu-asdiv-dataset/archive/55790e5270bb91ccfa5053194b25732534696b50.zip"
-        checksum = "8f1fe4f6d5f170ec1e24ab78c244153c14c568b1bb2b1dad0324e71f37939a2d"
-        zip_path = self.DATASET_PATH / "55790e5270bb91ccfa5053194b25732534696b50.zip"
-        download_file(url, local_file=str(zip_path), expected_checksum=checksum)
-        with ZipFile(zip_path, "r") as zip:
-            zip.extractall(self.DATASET_PATH)
-        os.remove(zip_path)
-
-    def _convert_standard(self, problem):
-        #TODO: include solution-type and formula
-        out_doc = {
-            "question": problem.find('Question').text,
-            "body": problem.find('Body').text,
-            "answer": problem.find('Answer').text
-        }
-        return out_doc
-
-    def load_docs(self, textfilename, tfds=False):
-        tree = ET.parse(textfilename)
-        root = tree.getroot()
-        for pid, problem in enumerate(root.iter('Problem')):
-            out_doc = self._convert_standard(problem)
-            yield out_doc
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.asdiv.asdiv)

     def has_training_docs(self):
         return False

     def has_validation_docs(self):
         return True

@@ -81,13 +48,12 @@ class Asdiv(Task):
     def training_docs(self):
         raise NotImplementedError("This dataset has no training docs")

+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def test_docs(self):
         raise NotImplementedError("This dataset has no test docs")

-    def validation_docs(self):
-        data_xml_path = self.DATASET_PATH / "nlu-asdiv-dataset-55790e5270bb91ccfa5053194b25732534696b50/dataset/ASDiv.xml"
-        return self.load_docs(data_xml_path)
-
     def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
         assert num_fewshot == 0, "ASDiv is intended only for the zero-shot setting."
         return super().fewshot_context(

@@ -101,6 +67,12 @@ class Asdiv(Task):
         # TODO: add solution-type
         return doc['body'] + '\n' + 'Question:' + doc['question'] + '\n' + 'Answer:'

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc['body'] + " " + doc['question']
+
     def doc_to_target(self, doc):
         # TODO: add formula
lm_eval/tasks/blimp.py — modified

@@ -10,9 +10,8 @@ grammars.

 Homepage: https://github.com/alexwarstadt/blimp
 """
-from lm_eval.base import rf
+from lm_eval.base import rf, Task
 from lm_eval.metrics import mean
-from .common import HFTask


 _CITATION = """

@@ -32,19 +31,24 @@ _CITATION = """
 """


-class BlimpTask(HFTask):
+class BlimpTask(Task):
     VERSION = 0
     DATASET_PATH = "blimp"

-    def download(self):
-        super().download()
+    def has_training_docs(self):
+        return False
+
+    def has_validation_docs(self):
+        return True
+
+    def has_test_docs(self):
+        return False
+
+    def validation_docs(self):
         # The HF dataset only contains a "train" dataset, but the harness expects a "validation"
         # dataset. Let's use the training dataset, on the assumption that the model wasn't actually
         # trained on this data.
-        self.data["validation"] = self.data["train"]
-        del self.data["train"]
+        return self.dataset["train"]

     def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
         assert num_fewshot == 0

@@ -64,6 +68,12 @@ class BlimpTask(HFTask):
         # this method is invoked by tests only
         return ""

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["sentence_good"] + " " + doc["sentence_bad"]
+
     def doc_to_target(self, doc):
         # this method is invoked by tests only
         return ""
lm_eval/tasks/cbt.py — modified

@@ -13,9 +13,8 @@ used by the Recurrent Language Models described in the paper. See section 4.4.
 Homepage: https://github.com/facebookresearch/ParlAI/tree/main/parlai/tasks/cbt
 """
 import numpy as np
-from lm_eval.base import rf
+from lm_eval.base import rf, Task
 from lm_eval.metrics import mean
-from .common import HFTask


 _CITATION = """

@@ -30,11 +29,30 @@ _CITATION = """
 """


-class CBTBase(HFTask):
+class CBTBase(Task):
     VERSION = 0
     DATASET_PATH = "cbt"
     DATASET_NAME = None

+    def has_training_docs(self):
+        return True
+
+    def has_validation_docs(self):
+        return True
+
+    def has_test_docs(self):
+        return True
+
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
+    def test_docs(self):
+        return self.dataset["test"]
+
     def detokenize(self, text):
         text = text.replace(" '", "'")

@@ -57,6 +75,13 @@ class CBTBase(HFTask):
         text = "Passage: " + passage + "\nQuestion: " + doc["question"]
         return self.detokenize(text)

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        passage = " ".join(doc["sentences"])
+        return passage
+
     def doc_to_target(self, doc):
         return ""
lm_eval/tasks/common.py — deleted (100644 → 0); the removed file follows.

import datasets

from ..base import Task


class HFTask(Task):
    DATASET_PATH = None
    DATASET_NAME = None

    def __init__(self):
        self.data = None
        super().__init__()

    def download(self):
        self.data = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)

    def has_training_docs(self):
        """Whether the task has a training set"""
        return True if "train" in self.data.keys() else False

    def has_validation_docs(self):
        """Whether the task has a validation set"""
        return True if "validation" in self.data.keys() else False

    def has_test_docs(self):
        """Whether the task has a test set"""
        return True if "test" in self.data.keys() else False

    def _convert_standard(self, doc):
        return doc

    def training_docs(self):
        # Cache training for faster few-shot.
        # If data is too large to fit in memory, override this method.
        if self.has_training_docs():
            if self._training_docs is None:
                self._training_docs = list(map(self._convert_standard, self.data["train"]))
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            return map(self._convert_standard, self.data["validation"])

    def test_docs(self):
        if self.has_test_docs():
            return map(self._convert_standard, self.data["test"])


def yesno(x):
    if x:
        return 'yes'
    else:
        return 'no'
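The deletion above is the core of this refactor: instead of inheriting download/split plumbing from HFTask, each task now subclasses lm_eval.base.Task directly and reads splits from self.dataset. A minimal sketch of the replacement pattern, mirroring the GLUE changes below; the class name is hypothetical:

from lm_eval.base import Task

class ExampleHFBackedTask(Task):
    VERSION = 0
    DATASET_PATH = "glue"   # any HuggingFace dataset path
    DATASET_NAME = "cola"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        # Cache the training split for faster few-shot sampling.
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        return self.dataset["validation"]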
lm_eval/tasks/coqa.py — modified

@@ -9,13 +9,11 @@ appear in a conversation.

 Homepage: https://stanfordnlp.github.io/coqa/
 """
-import os
-import json
+import inspect
 import transformers.data.metrics.squad_metrics as squad_metrics
+import lm_eval.datasets.coqa.coqa
 from lm_eval.base import Task, rf, mean
-from ..utils import sh
 from itertools import zip_longest
-from best_download import download_file


 _CITATION = """

@@ -32,15 +30,8 @@ _CITATION = """

 class CoQA(Task):
     VERSION = 1
-
-    def download(self):
-        coqa_train_filepath = 'data/coqa/coqa-train-v1.0.json'
-        coqa_dev_filepath = 'data/coqa/coqa-dev-v1.0.json'
-        sh("""mkdir -p data/coqa""")
-        download_file("http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json", local_file=coqa_train_filepath, expected_checksum="b0fdb2bc1bd38dd3ca2ce5fa2ac3e02c6288ac914f241ac409a655ffb6619fa6")
-        download_file("http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json", local_file=coqa_dev_filepath, expected_checksum="dfa367a9733ce53222918d0231d9b3bedc2b8ee831a2845f62dfc70701f2540a")
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.coqa.coqa)
+    DATASET_NAME = None

     def has_training_docs(self):
         return True

@@ -52,10 +43,10 @@ class CoQA(Task):
         return False

     def training_docs(self):
-        return json.load(open('data/coqa/coqa-train-v1.0.json'))['data']
+        return self.dataset["train"]

     def validation_docs(self):
-        return json.load(open('data/coqa/coqa-dev-v1.0.json'))['data']
+        return self.dataset["validation"]

     def test_docs(self):
         pass

@@ -64,23 +55,29 @@ class CoQA(Task):
         # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
         # and a question qi, the task is to predict the answer ai
         doc_text = doc["story"] + '\n\n'
-        for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]):  # omit target answer ai
-            question = f"Q: {q['input_text']}" + '\n\n'
-            answer = f"A: {a['input_text']}" + '\n\n' if a is not None else "A:"
+        for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]):  # omit target answer ai
+            question = f"Q: {q}\n\n"
+            answer = f"A: {a}\n\n" if a is not None else "A:"
             doc_text += question + answer
         return doc_text

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["story"] + " " + "\n".join(doc["questions"]["input_text"])
+
     @classmethod
     def get_answers(cls, doc, turn_id):
         # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
         answers = []
-        answer_forturn = doc["answers"][turn_id - 1]["input_text"]
+        answer_forturn = doc["answers"]["input_text"][turn_id - 1]
         answers.append(answer_forturn)

         additional_answers = doc.get("additional_answers")
         if additional_answers:
             for key in additional_answers:
-                additional_answer_for_turn = additional_answers[key][turn_id - 1]["input_text"]
+                additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1]
                 if additional_answer_for_turn.lower() not in map(str.lower, answers):
                     answers.append(additional_answer_for_turn)
         return answers

@@ -120,8 +117,8 @@ class CoQA(Task):
     def doc_to_target(self, doc, turnid=None):
         # Default to prediction of last turn.
         if turnid is None:
-            turnid = len(doc["questions"])
-        raw_text = doc['answers'][turnid - 1]["input_text"]
+            turnid = len(doc["questions"]["input_text"])
+        raw_text = doc['answers']["input_text"][turnid - 1]
         return " " + raw_text

     def construct_requests(self, doc, ctx):

@@ -148,7 +145,7 @@ class CoQA(Task):
         :param results:
             The results of the requests created in construct_requests.
         """
-        turn_id = len(doc["questions"])
+        turn_id = len(doc["questions"]["input_text"])
         gold_list = self.get_answers(doc, turn_id)
         pred = results[0].strip().split('\n')[0]
lm_eval/tasks/drop.py — modified

@@ -12,16 +12,14 @@ Homepage: https://allenai.org/data/drop
 Acknowledgement: This implementation is based on the official evaluation for `DROP`:
 https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py
 """
-import json
+import inspect
 import numpy as np
 import re
 import string
-from best_download import download_file
+import lm_eval.datasets.drop.drop
 from scipy.optimize import linear_sum_assignment
 from lm_eval.base import Task, rf
 from lm_eval.metrics import mean
-from pathlib import Path
-from zipfile import ZipFile


 _CITATION = """

@@ -41,18 +39,8 @@ _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)

 class DROP(Task):
     VERSION = 1
-    DATASET_PATH = Path("data/drop")
-
-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        url = "https://s3-us-west-2.amazonaws.com/allennlp/datasets/drop/drop_dataset.zip"
-        checksum = "39d2278a29fd729de301b111a45f434c24834f40df8f4ff116d864589e3249d6"
-        zip_path = self.DATASET_PATH / "drop_dataset.zip"
-        download_file(url, local_file=str(zip_path), expected_checksum=checksum)
-        with ZipFile(zip_path, "r") as zip:
-            zip.extractall(self.DATASET_PATH)
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.drop.drop)
+    DATASET_NAME = None

     def has_training_docs(self):
         return True

@@ -63,29 +51,46 @@ class DROP(Task):
     def has_test_docs(self):
         return False

-    def _load_docs(self, docs):
-        for doc in docs:
-            for qa in doc["qa_pairs"]:
-                yield {
-                    "id": qa["query_id"],
-                    "passage": doc["passage"],
-                    "question": qa["question"],
-                    "answers": self.get_answers(qa),
-                }
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
+    def _process_doc(self, doc):
+        return {
+            "id": doc["query_id"],
+            "passage": doc["passage"],
+            "question": doc["question"],
+            "answers": self.get_answers(doc),
+        }

     @classmethod
     def get_answers(cls, qa):
+        def _flatten_validated_answers(validated_answers):
+            """ Flattens a dict of lists of validated answers.
+            {"number": ['1', '8'], ...}
+            -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
+            """
+            vas = []
+            for i in range(len(validated_answers["number"])):
+                vas.append({
+                    "number": validated_answers["number"][i],
+                    "date": validated_answers["date"][i],
+                    "spans": validated_answers["spans"][i],
+                })
+            return vas
+
         answers = []
         answers_set = set()
-        candidates = [qa["answer"]] + qa.get("validated_answers", [])
+        candidates = [qa["answer"]] + _flatten_validated_answers(qa["validated_answers"])
         for candidate in candidates:
             answer = cls.parse_answer(candidate)
             if answer in answers_set:
                 continue
             answers_set.add(answer)
             answers.append(answer)
         return answers

     @classmethod

@@ -99,17 +104,15 @@ class DROP(Task):
                 answer["date"]["month"],
                 answer["date"]["year"]]).strip(),)

-    def training_docs(self):
-        docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_train.json"))
-        return self._load_docs([docs[k] for k in docs.keys()])
-
-    def validation_docs(self):
-        docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_dev.json"))
-        return self._load_docs([docs[k] for k in docs.keys()])
-
     def doc_to_text(self, doc):
         return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc['passage'] + " " + doc['question']
+
     def doc_to_target(self, doc):
         return " " + ", ".join(doc["answers"][0])
lm_eval/tasks/glue.py — modified

@@ -14,10 +14,9 @@ respect to a wide range of linguistic phenomena found in natural language.
 Homepage: https://gluebenchmark.com/
 """
 import numpy as np
-from lm_eval.base import rf
-from ..metrics import mean, matthews_corrcoef, f1_score
-from .common import HFTask, yesno
-from ..utils import general_detokenize
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean, matthews_corrcoef, f1_score, yesno
+from lm_eval.utils import general_detokenize


 # TODO(jon-tow): Add citations for the individual datasets/tasks that make up GLUE.

@@ -46,7 +45,7 @@ _CITATION = """
 # Single-Sentence Tasks


-class CoLA(HFTask):
+class CoLA(Task):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "cola"

@@ -60,9 +59,23 @@ class CoLA(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format(doc["sentence"])

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["sentence"]
+
     def doc_to_target(self, doc):
         return " {}".format({1: "yes", 0: "no"}[doc["label"]])

@@ -90,7 +103,7 @@ class CoLA(HFTask):
     }


-class SST(HFTask):
+class SST(Task):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "sst2"

@@ -104,6 +117,14 @@ class SST(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return "{}\nQuestion: Is this sentence positive or negative?\nAnswer:".format(
             general_detokenize(doc["sentence"]),

@@ -139,7 +160,7 @@ class SST(HFTask):
 # Inference Tasks


-class MNLI(HFTask):
+class MNLI(Task):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "mnli"

@@ -153,13 +174,18 @@ class MNLI(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
     def validation_docs(self):
         if self.has_validation_docs():
-            return self.data["validation_matched"]
+            return self.dataset["validation_matched"]

     def test_docs(self):
         if self.has_test_docs():
-            return self.data["test_matched"]
+            return self.dataset["test_matched"]

     def doc_to_text(self, doc):
         return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format(

@@ -202,14 +228,14 @@ class MNLIMismatched(MNLI):

     def validation_docs(self):
         if self.has_validation_docs():
-            return self.data["validation_mismatched"]
+            return self.dataset["validation_mismatched"]

     def test_docs(self):
         if self.has_test_docs():
-            return self.data["test_mismatched"]
+            return self.dataset["test_mismatched"]


-class QNLI(HFTask):
+class QNLI(Task):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "qnli"

@@ -223,6 +249,14 @@ class QNLI(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return "{}\n{}\nQuestion: Does this response answer the question?\nAnswer:".format(
             doc["question"],

@@ -258,7 +292,7 @@ class QNLI(HFTask):
     }


-class WNLI(HFTask):
+class WNLI(Task):
     VERSION = 1
     DATASET_PATH = "glue"
     DATASET_NAME = "wnli"

@@ -272,6 +306,14 @@ class WNLI(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return "{}\nQuestion: {} True or False?\nAnswer:".format(
             doc["sentence1"],

@@ -307,7 +349,7 @@ class WNLI(HFTask):
     }


-class RTE(HFTask):
+class RTE(Task):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "rte"

@@ -321,6 +363,14 @@ class RTE(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return "{}\nQuestion: {} True or False?\nAnswer:".format(
             doc["sentence1"],

@@ -359,7 +409,7 @@ class RTE(HFTask):
 # Similarity and Paraphrase Tasks


-class MRPC(HFTask):
+class MRPC(Task):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "mrpc"

@@ -373,6 +423,14 @@ class MRPC(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return "Sentence 1: {}\nSentence 2: {}\nQuestion: Do both sentences mean the same thing?\nAnswer:".format(
             general_detokenize(doc["sentence1"]),

@@ -409,7 +467,7 @@ class MRPC(HFTask):
     }


-class QQP(HFTask):
+class QQP(Task):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "qqp"

@@ -423,6 +481,14 @@ class QQP(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return "Question 1: {}\nQuestion 2: {}\nQuestion: Do both questions ask the same thing?\nAnswer:".format(
             doc["question1"],

@@ -459,7 +525,7 @@ class QQP(HFTask):
     }


-class STSB(HFTask):
+class STSB(Task):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "stsb"

@@ -473,6 +539,17 @@ class STSB(HFTask):
     def has_test_docs(self):
         return True

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
+    def test_docs(self):
+        return self.dataset["test"]
+
     def doc_to_text(self, doc):
         return "sentence 1: {}\nsentence 2: {}\nAnswer:".format(
             doc["sentence1"],
lm_eval/tasks/gsm8k.py — modified

@@ -16,10 +16,9 @@ model's sample/generation function.

 Homepage: https://github.com/openai/grade-school-math
 """
-import json
+import inspect
 import re
-from best_download import download_file
-from pathlib import Path
+import lm_eval.datasets.gsm8k.gsm8k
 from lm_eval.base import Task, rf
 from lm_eval.metrics import mean

@@ -43,21 +42,8 @@ INVALID_ANS = "[invalid]"

 class GradeSchoolMath8K(Task):
     VERSION = 0
-    DATASET_PATH = Path('data/gsm8k')
-
-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        base_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data"
-        splits = [
-            {"name": "train", "checksum": "17f347dc51477c50d4efb83959dbb7c56297aba886e5544ee2aaed3024813465"},
-            {"name": "test", "checksum": "3730d312f6e3440559ace48831e51066acaca737f6eabec99bccb9e4b3c39d14"},
-        ]
-        for split in splits:
-            file = self.DATASET_PATH / f"{split['name']}.jsonl"
-            url = f"{base_url}/{split['name']}.jsonl"
-            download_file(url, local_file=str(file), expected_checksum=split["checksum"])
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.gsm8k.gsm8k)
+    DATASET_NAME = None

     def has_training_docs(self):
         return True

@@ -68,17 +54,14 @@ class GradeSchoolMath8K(Task):
     def has_test_docs(self):
         return True

-    def _load_docs(self, file):
-        return (json.loads(line) for line in open(file).read().splitlines())
-
     def training_docs(self):
-        return self._load_docs(self.DATASET_PATH / "train.jsonl")
+        return self.dataset["train"]

     def validation_docs(self):
         raise NotImplementedError

     def test_docs(self):
-        return self._load_docs(self.DATASET_PATH / "test.jsonl")
+        return self.dataset["test"]

     def doc_to_text(self, doc):
         return "Question: " + doc['question'] + '\nAnswer:'