gaoqiong / lm-evaluation-harness / Commits / 121b7096

Commit 121b7096, authored May 02, 2022 by Fabrizio Milo

    add pre-commit

Parent: 7a038118
Changes: 120
Showing 20 changed files with 544 additions and 396 deletions (+544 -396)
lm_eval/datasets/truthfulqa/truthfulqa.py        +36  -26
lm_eval/datasets/unscramble/dataset_infos.json    +1   -1
lm_eval/datasets/unscramble/unscramble.py         +3   -2
lm_eval/datasets/wikitext/dataset_infos.json      +1   -1
lm_eval/datasets/wikitext/wikitext.py            +55  -30
lm_eval/decontamination/archiver.py              +47  -33
lm_eval/decontamination/decontaminate.py         +32  -17
lm_eval/decontamination/janitor.py               +45  -33
lm_eval/evaluator.py                             +80  -44
lm_eval/metrics.py                               +15  -10
lm_eval/models/gpt2.py                           +42  -19
lm_eval/models/gpt3.py                           +36  -21
lm_eval/tasks/__init__.py                        +22  -35
lm_eval/tasks/anli.py                            +23  -24
lm_eval/tasks/arithmetic.py                       +4   -8
lm_eval/tasks/asdiv.py                           +12  -19
lm_eval/tasks/blimp.py                           +10   -4
lm_eval/tasks/cbt.py                              +8  -12
lm_eval/tasks/coqa.py                            +37  -28
lm_eval/tasks/drop.py                            +35  -29
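The hunks below are consistent with running an auto-formatter (black-style) through pre-commit: single-quoted strings become double-quoted, long calls and signatures are split one argument per line, and trailing commas are added before the closing bracket. The exact hook configuration is not shown on this page. As a minimal, hypothetical illustration of that transformation (the function below is a stand-in that only mirrors the signature shape of lm_eval.evaluator.simple_evaluate, not code from this commit):

    # Before formatting: one long line, single quotes, no trailing comma, e.g.
    # def simple_evaluate(model, model_args=None, tasks=[], ..., decontamination_ngrams_path=None):

    # After formatting: double quotes, one parameter per line, trailing comma.
    def simple_evaluate_demo(
        model,
        model_args=None,
        tasks=[],
        num_fewshot=0,
        batch_size=None,
        device=None,
        no_cache=False,
        limit=None,
        bootstrap_iters=100000,
        description_dict=None,
        check_integrity=False,
        decontamination_ngrams_path=None,
    ):
        # Stand-in body; the real simple_evaluate lives in lm_eval/evaluator.py (see its hunk below).
        return {"model": model, "tasks": tasks, "num_fewshot": num_fewshot}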
lm_eval/datasets/truthfulqa/truthfulqa.py (view file @ 121b7096)

@@ -65,13 +65,14 @@ class TruthfulqaConfig(datasets.BuilderConfig):
 class Truthfulqa(datasets.GeneratorBasedBuilder):
     """TruthfulQA is a benchmark to measure whether a language model is truthful in
 generating answers to questions."""
     BUILDER_CONFIGS = [
         TruthfulqaConfig(
             name="multiple_choice",
             url="https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/data/mc_task.json",
-            features=datasets.Features({
+            features=datasets.Features(
+                {
                     "question": datasets.Value("string"),
                     "mc1_targets": {
                         "choices": datasets.features.Sequence(datasets.Value("string")),

@@ -80,23 +81,30 @@ generating answers to questions."""
                     "mc2_targets": {
                         "choices": datasets.features.Sequence(datasets.Value("string")),
                         "labels": datasets.features.Sequence(datasets.Value("int32")),
                     },
-            }),
-            description="The multiple choice TruthfulQA task"),
+                }
+            ),
+            description="The multiple choice TruthfulQA task",
+        ),
         TruthfulqaConfig(
             name="generation",
             url="https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/TruthfulQA.csv",
-            features=datasets.Features({
+            features=datasets.Features(
+                {
                     "category": datasets.Value("string"),
                     "question": datasets.Value("string"),
                     "best_answer": datasets.Value("string"),
-                "correct_answers": datasets.features.Sequence(datasets.Value("string")),
-                "incorrect_answers": datasets.features.Sequence(datasets.Value("string")),
+                    "correct_answers": datasets.features.Sequence(
+                        datasets.Value("string")
+                    ),
+                    "incorrect_answers": datasets.features.Sequence(
+                        datasets.Value("string")
+                    ),
                     "source": datasets.Value("string"),
-            }),
-            description="The generative TruthfulQA task")
+                }
+            ),
+            description="The generative TruthfulQA task",
+        ),
     ]

     def _info(self):

@@ -138,15 +146,15 @@ generating answers to questions."""
                     "mc2_targets": {
                         "choices": row["mc2_targets"].keys(),
                         "labels": row["mc2_targets"].values(),
-                    }
+                    },
                 }
         else:
             # Generation data is in a `CSV` file.
-            with open(filepath, newline='') as f:
+            with open(filepath, newline="") as f:
                 contents = csv.DictReader(f)
                 for key, row in enumerate(contents):
                     # Ensure that references exist.
-                    if not row['Correct Answers'] or not row['Incorrect Answers']:
+                    if not row["Correct Answers"] or not row["Incorrect Answers"]:
                         continue

@@ -154,6 +162,8 @@ generating answers to questions."""
                         "best_answer": row["Best Answer"],
                         # split on ";"
                         "correct_answers": row["Correct Answers"].strip().split(";"),
-                        "incorrect_answers": row["Incorrect Answers"].strip().split(";"),
+                        "incorrect_answers": row["Incorrect Answers"]
+                        .strip()
+                        .split(";"),
                         "source": row["Source"],
                     }
lm_eval/datasets/unscramble/dataset_infos.json (view file @ 121b7096)
lm_eval/datasets/unscramble/unscramble.py (view file @ 121b7096)

@@ -64,8 +64,9 @@ class Unscramble(datasets.GeneratorBasedBuilder):
     VERSION = datasets.Version("0.0.1")
     BUILDER_CONFIGS = [
-        datasets.BuilderConfig(name=name, version=version, description=_DESCRIPTIONS[name])
+        datasets.BuilderConfig(
+            name=name, version=version, description=_DESCRIPTIONS[name]
+        )
         for name, version in zip(_NAMES, [VERSION] * len(_NAMES))
     ]
lm_eval/datasets/wikitext/dataset_infos.json (view file @ 121b7096)
lm_eval/datasets/wikitext/wikitext.py (view file @ 121b7096)

@@ -123,86 +123,111 @@ class Wikitext(datasets.GeneratorBasedBuilder):
Each datasets.SplitGenerator(...) in this hunk (the wikitext-103, wikitext-103-raw,
wikitext-2-raw and wikitext-2 branches) has its gen_kwargs dict expanded onto multiple
lines with a trailing comma, e.g. for the TEST split of the tokens config:
-                    gen_kwargs={"data_file": os.path.join(data_dir, "wiki.test.tokens"), "split": "test"},
+                    gen_kwargs={
+                        "data_file": os.path.join(data_dir, "wiki.test.tokens"),
+                        "split": "test",
+                    },
The same change is applied to the TRAIN and VALIDATION split generators of each branch
(wiki.train.tokens / wiki.valid.tokens, and wiki.test.raw / wiki.train.raw / wiki.valid.raw
for the raw configs), and the download calls are wrapped:
-                data_file = dl_manager.download_and_extract(self.config.data_url)
+                data_file = dl_manager.download_and_extract(
+                    self.config.data_url
+                )

@@ -216,12 +241,12 @@ class Wikitext(datasets.GeneratorBasedBuilder):
                     data = f.read().split("\n")
                     for line in data:
                         rline = line.replace("= = =", "===").replace("= =", "==").strip()
-                        if rline.startswith('=') and rline.strip().endswith('='):
-                            page = '\n'.join(ret)
+                        if rline.startswith("=") and rline.strip().endswith("="):
+                            page = "\n".join(ret)
                             if page.strip():
                                 yield key, {"page": page}
                                 key += 1
                             ret = []
                         ret.append(line)
-                    page = '\n'.join(ret)
+                    page = "\n".join(ret)
                     yield key, {"page": page}
lm_eval/decontamination/archiver.py (view file @ 121b7096)

@@ -8,12 +8,14 @@ import mmap   (blank-line spacing around json_serial; the function body, including
raise TypeError("Type %s not serializable" % type(obj)), is otherwise unchanged)

@@ -22,25 +24,31 @@ class Archive:
-        self.fh = open(self.file_path, 'wb')
+        self.fh = open(self.file_path, "wb")
         self.cctx = zstandard.ZstdCompressor(level=compression_level)
         self.compressor = self.cctx.stream_writer(self.fh)

     def add_data(self, data, meta={}):
-        self.compressor.write(json.dumps({'text': data, 'meta': meta}, default=json_serial).encode('UTF-8') + b'\n')
+        self.compressor.write(
+            json.dumps({"text": data, "meta": meta}, default=json_serial).encode("UTF-8")
+            + b"\n"
+        )

 # Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
 class Reader:
-    def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n'):
-        with open(file, 'rb') as fh:
+    def read(
+        self, file, get_meta=False, autojoin_paragraphs=True, para_joiner="\n\n"
+    ):
+        with open(file, "rb") as fh:
             self.fh = fh
             cctx = zstandard.ZstdDecompressor()
             reader = io.BufferedReader(cctx.stream_reader(fh))

@@ -52,16 +60,17 @@ class Reader:
-                text = ob['text']
+                text = ob["text"]
                 if autojoin_paragraphs and isinstance(text, list):
                     text = para_joiner.join(text)
                 if get_meta:
-                    yield text, (ob['meta'] if 'meta' in ob else {})
+                    yield text, (ob["meta"] if "meta" in ob else {})
                 else:
                     yield text

 class TextArchive:

@@ -75,12 +84,13 @@ class TextArchive:
     def add_data(self, data):
-        self.fh.write(data.encode('UTF-8') + b'\n')
+        self.fh.write(data.encode("UTF-8") + b"\n")

@@ -90,9 +100,12 @@ class TextReader:
     def read_tqdm(self, update_frequency=10000):
-        with open(self.file_path, 'r') as fh, \
-            tqdm.tqdm(total=os.path.getsize(self.file_path), dynamic_ncols=True, unit="byte", unit_scale=1) as progress:
+        with open(self.file_path, "r") as fh, tqdm.tqdm(
+            total=os.path.getsize(self.file_path),
+            dynamic_ncols=True,
+            unit="byte",
+            unit_scale=1,
+        ) as progress:
             with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                 for line in iter(mmap_obj.readline, b""):
                     line = line.decode("utf-8")

@@ -107,7 +120,7 @@ class TextReader:
     def read_and_tell(self):
-        with open(self.file_path, 'r', encoding="utf8") as fh:
+        with open(self.file_path, "r", encoding="utf8") as fh:

@@ -117,14 +130,14 @@ class TextReader:
     def read(self):
-        with open(self.file_path, 'r', encoding="utf8") as fh:
+        with open(self.file_path, "r", encoding="utf8") as fh:
     def read_slow(self):
-        with open(self.file_path, 'r', encoding="utf8") as fh:
+        with open(self.file_path, "r", encoding="utf8") as fh:
             while True:
                 line = fh.readline()
                 if line == -1 or line == "":

@@ -132,6 +145,7 @@ class TextReader:   (blank line added before the ZStdTextReader comment block:
"Optimized for speed. Decompresses the archive in shell before using the mmap'd TextReader.")
lm_eval/decontamination/decontaminate.py (view file @ 121b7096)

@@ -57,19 +57,27 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
         # Check if we've decontaminated this combination before
-        overlaps_dump_path = get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit)
+        overlaps_dump_path = get_overlaps_dump_path(
+            task_name, task_set, ngrams_n_size, limit
+        )
         if os.path.exists(overlaps_dump_path):
-            duplicates[(task_name, task_set)] = pickle.load(open(overlaps_dump_path, "rb"))
+            duplicates[(task_name, task_set)] = pickle.load(
+                open(overlaps_dump_path, "rb")
+            )
             sets_to_decontaminate -= 1
             continue
         else:
             duplicates[(task_name, task_set)] = set()

         # Build/load the task lookup {ngram: set(documents)}.
-        task_set_lookup_path = f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
+        task_set_lookup_path = (
+            f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
+        )
         if os.path.exists(task_set_lookup_path):
             print(f"{task_set_lookup_path} available, loading...")
-            lookups[(task_name, task_set)] = pickle.load(open(task_set_lookup_path, "rb"))
+            lookups[(task_name, task_set)] = pickle.load(
+                open(task_set_lookup_path, "rb")
+            )
         else:
             print(f"{task_set_lookup_path} not available, building...")
             lookup = collections.defaultdict(set)

@@ -79,7 +87,7 @@   (spacing-only change around pickle.dump(lookup, open(task_set_lookup_path, "wb")))

@@ -115,7 +123,9 @@
         for line in reader.read_tqdm():  # Scan training set ngrams file
             total_ngrams += 1
             [ngram, document_id] = line.rsplit(" ", 1)
-            if ngram != current_ngram:  # Only need to match the ngram once in training set
+            if (
+                ngram != current_ngram
+            ):  # Only need to match the ngram once in training set
                 unique_ngrams += 1
                 current_ngram = ngram
                 if ngram in merged_lookup:

@@ -123,7 +133,11 @@
                     matching_unique += 1
                     for task_name, task_set, doc_ids in merged_lookup[ngram]:
                         task_doc_set = duplicates[(task_name, task_set)]
-                        for doc_id in doc_ids:  # Record contamination across all relevant task/set combos
+                        for (
+                            doc_id
+                        ) in (
+                            doc_ids
+                        ):  # Record contamination across all relevant task/set combos
                             task_doc_set.add(doc_id)
                     del merged_lookup[ngram]  # No point matching again

@@ -145,9 +159,10 @@
     # Dump overlaps separately
     for (task_name, task_set), doc_ids in duplicates.items():
-        overlaps_dump_path = get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit)
+        overlaps_dump_path = get_overlaps_dump_path(
+            task_name, task_set, ngrams_n_size, limit
+        )
         pickle.dump(doc_ids, open(overlaps_dump_path, "wb"))

     # Strip task set and return
     return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
lm_eval/decontamination/janitor.py (view file @ 121b7096)

@@ -9,6 +9,7 @@ from pprint import pprint
 # c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
 try:
     import janitor_util
+
     JANITOR_CPP = True
 except Exception as e:
     print("WARNING: C++ module could not be loaded. Janitor running in python mode")

@@ -41,6 +42,7 @@ def word_ngrams(s, n):   (blank line added after
"return (" ".join(ngram) for ngram in ngram_seqs)", before the commented-out
word_ngrams_indices_combined sketch)

@@ -70,7 +72,7 @@ def split_indices(s):
     """Splits a string on whitespaces and records the indices of each in the original string.
     @:return generator((word, (start_idx, end_idx)), ...)
     """
-    return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r'\S+', s))
+    return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))

@@ -90,10 +92,15 @@ def word_ngrams_indices(s, n):
-    ngram_indices_pairs = (zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices)
+    ngram_indices_pairs = (
+        zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices
+    )
     # Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
-    return ((" ".join(ngram_seq), (indices[0][0], indices[-1][1])) for ngram_seq, indices in ngram_indices_pairs)
+    return (
+        (" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
+        for ngram_seq, indices in ngram_indices_pairs
+    )

@@ -105,7 +112,7 @@ class Janitor:
         window_to_remove=200,
         too_dirty_cutoff=10,
         minimum_slice_length=200,
-        delete_chars=string.punctuation
+        delete_chars=string.punctuation,
     ):

@@ -121,7 +128,7 @@ class Janitor:
         self.translation_table = str.maketrans(
             string.ascii_lowercase + string.ascii_uppercase,  # These characters
             string.ascii_lowercase * 2,  # Become these characters
-            self.delete_chars  # These are deleted
+            self.delete_chars,  # These are deleted
         )

@@ -129,14 +136,13 @@ class Janitor:
     def save_contamination_ngrams(self, filename):
-        with open(filename, 'wb') as fp:
+        with open(filename, "wb") as fp:
             pickle.dump(filename, fp)
     def load_contamination_ngrams(self, filename):
-        with open(filename, 'rb') as fp:
+        with open(filename, "rb") as fp:
             self.dirt_ngrams = pickle.load(fp)

@@ -171,11 +177,11 @@ class Janitor:   (spacing-only changes in the two
clean_chunks.append(dirty_string[...]) slices)

@@ -184,10 +190,14 @@ class Janitor:
     def register_contaminant_cpp(self, dirt_string):
-        self.dirt_ngrams.update(janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n))
+        self.dirt_ngrams.update(
+            janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
+        )
     def clean_cpp(self, dirty_string):
-        contamination_indices = janitor_util.clean_ngram_with_indices(dirty_string, self.delete_chars, self.ngram_n)
+        contamination_indices = janitor_util.clean_ngram_with_indices(
+            dirty_string, self.delete_chars, self.ngram_n
+        )
         return self._split_chunks(dirty_string, contamination_indices)

@@ -198,7 +208,9 @@ class Janitor:
     def register_contaminant_python(self, dirt_string):
-        self.dirt_ngrams.update(word_ngrams(self.normalize_string(dirt_string), self.ngram_n))
+        self.dirt_ngrams.update(
+            word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
+        )
lm_eval/evaluator.py (view file @ 121b7096)

@@ -11,12 +11,22 @@ import numpy as np
 from lm_eval.utils import positional_deprecated, run_task_tests
 from lm_eval.decontamination.decontaminate import get_train_overlap

 @positional_deprecated
-def simple_evaluate(model, model_args=None, tasks=[], num_fewshot=0, batch_size=None, device=None, no_cache=False, limit=None, bootstrap_iters=100000, description_dict=None, check_integrity=False, decontamination_ngrams_path=None):
+def simple_evaluate(
+    model,
+    model_args=None,
+    tasks=[],
+    num_fewshot=0,
+    batch_size=None,
+    device=None,
+    no_cache=False,
+    limit=None,
+    bootstrap_iters=100000,
+    description_dict=None,
+    check_integrity=False,
+    decontamination_ngrams_path=None,
+):
     """Instantiate and evaluate a model on a list of tasks.

@@ -52,17 +62,23 @@ def simple_evaluate(model, model_args=None, tasks=[],
     assert tasks != [], "No tasks specified"
     if isinstance(model, str):
-        if model_args is None: model_args = ""
-        lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {'batch_size': batch_size, 'device': device})
+        if model_args is None:
+            model_args = ""
+        lm = lm_eval.models.get_model(model).create_from_arg_string(
+            model_args, {"batch_size": batch_size, "device": device}
+        )
     else:
         assert isinstance(model, lm_eval.base.LM)
         lm = model
     if not no_cache:
         lm = lm_eval.base.CachingLM(
-            lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db'
+            lm,
+            "lm_cache/"
+            + model
+            + "_"
+            + model_args.replace("=", "-").replace(",", "_").replace("/", "-")
+            + ".db",
         )
     task_dict = lm_eval.tasks.get_task_dict(tasks)

@@ -89,16 +105,26 @@ def simple_evaluate(model, model_args=None, tasks=[],
         "no_cache": no_cache,
         "limit": limit,
         "bootstrap_iters": bootstrap_iters,
-        "description_dict": description_dict
+        "description_dict": description_dict,
     }
     return results

 decontaminate_suffix = "_decontaminate"

 @positional_deprecated
-def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None, decontamination_ngrams_path=None):
+def evaluate(
+    lm,
+    task_dict,
+    provide_description=None,
+    num_fewshot=0,
+    limit=None,
+    bootstrap_iters=100000,
+    description_dict=None,
+    decontamination_ngrams_path=None,
+):
     """Instantiate and evaluate a model on a list of tasks.
     :param lm: obj

@@ -124,14 +150,16 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
     assert not provide_description  # not implemented.
     if provide_description is not None:
         # nudge people to not specify it at all
-        print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
+        print(
+            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
+        )
     decontaminate = decontamination_ngrams_path is not None
     task_dict_items = [
         (name, task)
         for name, task in task_dict.items()
         if (task.has_validation_docs() or task.has_test_docs())
     ]
     results = collections.defaultdict(dict)

@@ -172,19 +200,22 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         rnd.seed(42)
         rnd.shuffle(task_docs)
-        description = description_dict[task_name] if description_dict and task_name in description_dict else ""
+        description = (
+            description_dict[task_name]
+            if description_dict and task_name in description_dict
+            else ""
+        )
         for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
             if decontaminate and task.should_decontaminate():
-                docs_for_decontamination[(task_name, task_set)].append(task.doc_to_decontamination_query(doc))
+                docs_for_decontamination[(task_name, task_set)].append(
+                    task.doc_to_decontamination_query(doc)
+                )
             docs[(task_name, doc_id)] = doc
             ctx = task.fewshot_context(
                 doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
             )
             reqs = task.construct_requests(doc, ctx)
             if not isinstance(reqs, (list, tuple)):

@@ -198,7 +229,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
     # Compare all tasks/sets at once to ensure a single training set scan
     if decontaminate:
         print("Finding train/test overlap, please wait...")
-        overlaps = get_train_overlap(docs_for_decontamination, decontamination_ngrams_path, limit)
+        overlaps = get_train_overlap(
+            docs_for_decontamination, decontamination_ngrams_path, limit
+        )
     # all responses for each (task, doc)
     process_res_queue = collections.defaultdict(list)

@@ -212,7 +245,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         print("Running", reqtype, "requests")
         resps = getattr(lm, reqtype)([req.args for req in reqs])
-        resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]
+        resps = [
+            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
+        ]
         for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
             process_res_queue[(task_name, doc_id)].append((i, resp))

@@ -241,7 +276,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
             task = task_dict[task_name]
             real_metric = metric  # key when looking up the metric with task.aggregation
             if metric.endswith(decontaminate_suffix):
-                real_metric = metric.replace(decontaminate_suffix, "")  # decontaminated still uses the same metric
+                real_metric = metric.replace(
+                    decontaminate_suffix, ""
+                )  # decontaminated still uses the same metric
             results[task_name][metric] = task.aggregation()[real_metric](items)
             # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap

@@ -249,16 +286,15 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
             stderr = lm_eval.metrics.stderr_for_metric(
                 metric=task.aggregation()[real_metric],
-                bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters,
+                bootstrap_iters=min(bootstrap_iters, 1000)
+                if metric in ["bleu", "chrf", "ter"]
+                else bootstrap_iters,
             )
             if stderr is not None:
                 results[task_name][metric + "_stderr"] = stderr(items)
-    return {
-        "results": dict(results),
-        "versions": dict(versions)
-    }
+    return {"results": dict(results), "versions": dict(versions)}

 def make_table(result_dict):

@@ -280,9 +316,9 @@ def make_table(result_dict):
         if m + "_stderr" in dic:
             se = dic[m + "_stderr"]
-            values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
+            values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
         else:
-            values.append([k, version, m, '%.4f' % v, '', ''])
+            values.append([k, version, m, "%.4f" % v, "", ""])
         k = ""
         version = ""
     md_writer.value_matrix = values
lm_eval/metrics.py (view file @ 121b7096)

@@ -103,6 +103,7 @@ def weighted_mean(items):   (blank line added between weighted_perplexity and
bits_per_byte: return -weighted_mean(items) / math.log(2))

@@ -184,8 +185,10 @@ def _sacreformat(refs, preds):   (blank lines added around the "# stderr stuff"
comment before class _bootstrap_internal)

@@ -203,6 +206,7 @@ class _bootstrap_internal:
 def bootstrap_stderr(f, xs, iters):
     import multiprocessing as mp
+
     pool = mp.Pool(mp.cpu_count())

@@ -213,10 +217,15 @@ def bootstrap_stderr(f, xs, iters):
     res = []
     chunk_size = min(1000, iters)
     from tqdm import tqdm
+
     print("bootstrapping for stddev:", f.__name__)
-    for bootstrap in tqdm(pool.imap(_bootstrap_internal(f, chunk_size), [(i, xs) for i in range(iters // chunk_size)]), total=iters // chunk_size):
+    for bootstrap in tqdm(
+        pool.imap(
+            _bootstrap_internal(f, chunk_size),
+            [(i, xs) for i in range(iters // chunk_size)],
+        ),
+        total=iters // chunk_size,
+    ):
         # sample w replacement
         res.extend(bootstrap)

@@ -238,17 +247,13 @@ def stderr_for_metric(metric, bootstrap_iters):
     if metric in bootstrappable:
         return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
-    stderr = {
-        mean: mean_stderr,
-        acc_all: acc_all_stderr
-    }
+    stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
     return stderr.get(metric, None)

 def yesno(x):
     if x:
-        return 'yes'
+        return "yes"
     else:
-        return 'no'
+        return "no"
lm_eval/models/gpt2.py (view file @ 121b7096)

@@ -4,8 +4,15 @@ from lm_eval.base import BaseLM
 class HFLM(BaseLM):
-    def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1):
+    def __init__(
+        self,
+        device="cuda",
+        pretrained="gpt2",
+        revision="main",
+        subfolder=None,
+        tokenizer=None,
+        batch_size=1,
+    ):
         super().__init__()
         assert isinstance(device, str)

@@ -20,28 +27,47 @@ class HFLM(BaseLM):
         else:
             print("Device not specificed")
             print(f"Cuda Available? {torch.cuda.is_available()}")
-            self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+            self._device = (
+                torch.device("cuda")
+                if torch.cuda.is_available()
+                else torch.device("cpu")
+            )
         # TODO: update this to be less of a hack once subfolder is fixed in HF
         self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
-            pretrained, revision=revision + ("/" + subfolder if subfolder is not None else "")
+            pretrained,
+            revision=revision + ("/" + subfolder if subfolder is not None else ""),
         ).to(self.device)
         self.gpt2.eval()
         # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
-            pretrained if tokenizer is None else tokenizer, revision=revision, subfolder=subfolder)
+            pretrained if tokenizer is None else tokenizer,
+            revision=revision,
+            subfolder=subfolder,
+        )
-        assert isinstance(self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast, transformers.T5Tokenizer, transformers.T5TokenizerFast,)), "this tokenizer has not been checked for compatibility yet!"
+        assert isinstance(
+            self.tokenizer,
+            (
+                transformers.GPT2Tokenizer,
+                transformers.GPT2TokenizerFast,
+                transformers.T5Tokenizer,
+                transformers.T5TokenizerFast,
+            ),
+        ), "this tokenizer has not been checked for compatibility yet!"
         self.vocab_size = self.tokenizer.vocab_size
-        if isinstance(self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)):
-            assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373], \
-                self.tokenizer.encode('hello\n\nhello')
+        if isinstance(
+            self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
+        ):
+            assert self.tokenizer.encode("hello\n\nhello") == [
+                31373,
+                198,
+                198,
+                31373,
+            ], self.tokenizer.encode("hello\n\nhello")
         # multithreading and batching
         self.batch_size_per_gpu = batch_size  # todo: adaptive batch size

@@ -97,10 +123,7 @@ class HFLM(BaseLM):
     def _model_generate(self, context, max_length, eos_token_id):
         return self.gpt2.generate(
-            context,
-            max_length=max_length,
-            eos_token_id=eos_token_id,
-            do_sample=False
+            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
         )
lm_eval/models/gpt3.py (view file @ 121b7096)

@@ -36,17 +36,19 @@ def get_result(response, ctxlen):
 def oa_completion(**kwargs):
-    """
-    Query OpenAI API for completion.
+    """Query OpenAI API for completion.

     Retry with back-off until they respond
     """
     import openai
+
     backoff_time = 3
     while True:
         try:
             return openai.Completion.create(**kwargs)
         except openai.error.OpenAIError:
             import traceback
+
             traceback.print_exc()
             time.sleep(backoff_time)
             backoff_time *= 1.5

@@ -66,16 +68,19 @@ class GPT3LM(BaseLM):
         super().__init__()
         import openai
+
         self.engine = engine
-        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
+        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
         self.vocab_size = self.tokenizer.vocab_size
         # to make the annoying "Using pad_token, but it is not set yet." error go away
         self.tokenizer.pad_token = "<|endoftext|>"
-        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
+        assert self.tokenizer.encode("hello\n\nhello") == [31373, 198, 198, 31373]
         self.truncate = truncate
-        self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])[0]
+        self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(
+            ["<|endoftext|>"]
+        )[0]
         # Read from environment variable OPENAI_API_SECRET_KEY
         openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

@@ -121,14 +126,19 @@ class GPT3LM(BaseLM):
         reord = utils.Reorderer(requests, _collate)
-        for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE)), disable=disable_tqdm):
+        for chunk in tqdm(
+            list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE)),
+            disable=disable_tqdm,
+        ):
             inps = []
             ctxlens = []
             for cache_key, context_enc, continuation_enc in chunk:
                 # max_length+1 because the API takes up to 2049 tokens, including the first context token
-                inp = (context_enc + continuation_enc)[-(self.max_length + 1):]
+                inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
                 # TODO: the logic is much simpler if we just look at the length of continuation tokens
-                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - (self.max_length + 1))
+                ctxlen = len(context_enc) - max(
+                    0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
+                )
                 inps.append(inp)
                 ctxlens.append(ctxlen)

@@ -137,11 +147,14 @@ class GPT3LM(BaseLM):
             engine=self.engine,
             prompt=inps,
             echo=True,
-            max_tokens=0, temperature=0.,
+            max_tokens=0,
+            temperature=0.0,
             logprobs=10,
         )
-        for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(response.choices, ctxlens, chunk):
+        for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
+            response.choices, ctxlens, chunk
+        ):
             answer = get_result(resp, ctxlen)
             res.append(answer)

@@ -177,24 +190,26 @@ class GPT3LM(BaseLM):
                     yield ret, lastuntil
         # todo: more intelligent batching for heterogeneous `until`
-        for chunk, until in tqdm(list(sameuntil_chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
+        for chunk, until in tqdm(
+            list(sameuntil_chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))
+        ):
             inps = []
             for context, _ in chunk:
                 context_enc = self.tok_encode(context)
-                inp = context_enc[-(self.max_length - self.max_gen_toks):]
+                inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                 inps.append(inp)
             response = oa_completion(
                 engine=self.engine,
                 prompt=inps,
                 max_tokens=self.max_gen_toks,
-                temperature=0.,
+                temperature=0.0,
                 logprobs=10,
                 stop=until,
             )
             for resp, (context, until_) in zip(response.choices, chunk):
-                s = resp['text']
+                s = resp["text"]
                 for term in until_:
                     s = s.split(term)[0]
lm_eval/tasks/__init__.py (view file @ 121b7096)

@@ -59,8 +59,8 @@ from . import storycloze
 # 6 total
 gpt3_translation_benchmarks = {
-    "wmt14": ['en-fr', 'fr-en'],  # French
-    "wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'],  # German, Romanian
+    "wmt14": ["en-fr", "fr-en"],  # French
+    "wmt16": ["en-ro", "ro-en", "de-en", "en-de"],  # German, Romanian
 }

@@ -68,7 +68,7 @@ gpt3_translation_benchmarks = {
     "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
-    "iwslt17": ['en-ar', 'ar-en']  # Arabic
+    "iwslt17": ["en-ar", "ar-en"],  # Arabic
 }

@@ -92,7 +92,7 @@ TASK_REGISTRY = {
-    #"stsb": glue.STSB, # not implemented yet
+    # "stsb": glue.STSB, # not implemented yet

@@ -103,34 +103,26 @@ TASK_REGISTRY = {
Blank separator lines between the registry groups are dropped and colon spacing is
normalized in the science/QA entries:
-    "pubmedqa" : pubmedqa.Pubmed_QA,
-    "sciq" : sciq.SciQ,
-    "qa4mre_2011" : qa4mre.QA4MRE_2011,
-    "qa4mre_2012" : qa4mre.QA4MRE_2012,
-    "qa4mre_2013" : qa4mre.QA4MRE_2013,
+    "pubmedqa": pubmedqa.Pubmed_QA,
+    "sciq": sciq.SciQ,
+    "qa4mre_2011": qa4mre.QA4MRE_2011,
+    "qa4mre_2012": qa4mre.QA4MRE_2012,
+    "qa4mre_2013": qa4mre.QA4MRE_2013,

@@ -152,21 +144,17 @@, @@ -177,7 +165,6 @@, @@ -191,22 +178,18 @@, @@ -230,7 +213,6 @@,
@@ -299,7 +281,6 @@ TASK_REGISTRY = {   (only blank separator lines removed in the
anli/ethics/truthfulqa/mutual/math, arithmetic, hendrycksTest/translation/unscramble,
pile, and blimp/storycloze groups; the registry entries themselves are unchanged)

@@ -325,17 +306,23 @@ def get_task_name_from_object(task_object):
     # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
-    return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__
+    return (
+        task_object.EVAL_HARNESS_NAME
+        if hasattr(task_object, "EVAL_HARNESS_NAME")
+        else type(task_object).__name__
+    )

 def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
     task_name_dict = {
         task_name: get_task(task_name)()
-        for task_name in task_name_list if isinstance(task_name, str)
+        for task_name in task_name_list
+        if isinstance(task_name, str)
     }
     task_name_from_object_dict = {
         get_task_name_from_object(task_object): task_object
-        for task_object in task_name_list if not isinstance(task_object, str)
+        for task_object in task_name_list
+        if not isinstance(task_object, str)
     }
     assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
     return {**task_name_dict, **task_name_from_object_dict}
lm_eval/tasks/anli.py (view file @ 121b7096)

@@ -64,7 +64,12 @@ class ANLIBase(Task):
         # of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly
         # appended onto the question, with no "Answer:" or even a newline. Do we *really*
         # want to do it exactly as OA did?
-        return doc['premise'] + '\nQuestion: ' + doc['hypothesis'] + ' True, False, or Neither?\nAnswer:'
+        return (
+            doc["premise"]
+            + "\nQuestion: "
+            + doc["hypothesis"]
+            + " True, False, or Neither?\nAnswer:"
+        )

@@ -76,10 +81,10 @@ class ANLIBase(Task):
         # True = entailment
         # False = contradiction
         # Neither = neutral
-        return " " + ["True", "Neither", "False"][doc['label']]
+        return " " + ["True", "Neither", "False"][doc["label"]]

     def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
+        """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.

@@ -106,9 +111,7 @@ class ANLIBase(Task):
         gold = doc["label"]
         pred = np.argmax(results)
-        return {
-            "acc": pred == gold
-        }
+        return {"acc": pred == gold}

@@ -116,9 +119,7 @@ class ANLIBase(Task):
-        return {
-            "acc": mean
-        }
+        return {"acc": mean}

@@ -126,9 +127,7 @@ class ANLIBase(Task):
-        return {
-            "acc": True
-        }
+        return {"acc": True}

 class ANLIRound1(ANLIBase):
lm_eval/tasks/arithmetic.py (view file @ 121b7096)

@@ -67,10 +67,8 @@ class Arithmetic(Task):
     def process_results(self, doc, results):
-        is_prediction, = results
-        return {
-            "acc": is_prediction
-        }
+        (is_prediction,) = results
+        return {"acc": is_prediction}

@@ -78,9 +76,7 @@ class Arithmetic(Task):
     def higher_is_better(self):
-        return {
-            "acc": True
-        }
+        return {"acc": True}

 class Arithmetic2DPlus(Arithmetic):
lm_eval/tasks/asdiv.py (view file @ 121b7096)

@@ -54,29 +54,28 @@ class Asdiv(Task):
     def test_docs(self):
         raise NotImplementedError("This dataset has no test docs")

-    def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
+    def fewshot_context(
+        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
+    ):
         assert num_fewshot == 0, "ASDiv is intended only for the zero-shot setting."
         return super().fewshot_context(
             doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
         )

     def doc_to_text(self, doc):
         # TODO: add solution-type
-        return doc['body'] + '\n' + 'Question:' + doc['question'] + '\n' + 'Answer:'
+        return doc["body"] + "\n" + "Question:" + doc["question"] + "\n" + "Answer:"

     def doc_to_decontamination_query(self, doc):
-        return doc['body'] + " " + doc['question']
+        return doc["body"] + " " + doc["question"]

     def doc_to_target(self, doc):
         # TODO: add formula
-        answer = doc['answer'].split(' (')[0]
+        answer = doc["answer"].split(" (")[0]
         return " " + answer

@@ -86,16 +85,10 @@ class Asdiv(Task):
     def process_results(self, doc, results):
         ll, is_greedy = results
-        return {
-            'acc': int(is_greedy)
-        }
+        return {"acc": int(is_greedy)}

     def aggregation(self):
-        return {
-            'acc': mean
-        }
+        return {"acc": mean}

     def higher_is_better(self):
-        return {
-            'acc': True
-        }
+        return {"acc": True}
lm_eval/tasks/blimp.py (view file @ 121b7096)

@@ -50,9 +50,13 @@ class BlimpTask(Task):
         # trained on this data.
         return self.dataset["train"]

-    def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
+    def fewshot_context(
+        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
+    ):
         assert num_fewshot == 0
-        assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`"
+        assert (
+            rnd is not None
+        ), "A `random.Random` generator argument must be provided to `rnd`"
         assert not provide_description, (
             "The `provide_description` arg will be removed in future versions. To prepend "
             "a custom description to the context, supply the corresponding string via the "

@@ -60,7 +64,9 @@ class BlimpTask(Task):
         )
         if provide_description is not None:
             # nudge people to not specify it at all
-            print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
+            print(
+                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
+            )
         return ""
lm_eval/tasks/cbt.py (view file @ 121b7096)

@@ -86,7 +86,9 @@ class CBTBase(Task):
         return ""

     def fewshot_examples(self, k, rnd):
-        assert k == 0, f"CBT is only implemented for the zero-shot setting. Given k={k}."
+        assert (
+            k == 0
+        ), f"CBT is only implemented for the zero-shot setting. Given k={k}."
         return super().fewshot_examples(k, rnd)

@@ -120,9 +122,7 @@ class CBTBase(Task):
         gold = doc["options"].index(doc["answer"])
         pred = np.argmax(results)
-        return {
-            "acc": pred == gold
-        }
+        return {"acc": pred == gold}

@@ -130,9 +130,7 @@ class CBTBase(Task):
-        return {
-            "acc": mean
-        }
+        return {"acc": mean}

@@ -140,9 +138,7 @@ class CBTBase(Task):
-        return {
-            "acc": True
-        }
+        return {"acc": True}

 class CBTCN(CBTBase):
lm_eval/tasks/coqa.py (view file @ 121b7096)

@@ -54,8 +54,10 @@ class CoQA(Task):
     def doc_to_text(self, doc):
         # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
         # and a question qi, the task is to predict the answer ai
-        doc_text = doc["story"] + '\n\n'
-        for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]):  # omit target answer ai
+        doc_text = doc["story"] + "\n\n"
+        for (q, a) in zip_longest(
+            doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
+        ):  # omit target answer ai
             question = f"Q: {q}\n\n"
             answer = f"A: {a}\n\n" if a is not None else "A:"
             doc_text += question + answer

@@ -77,7 +79,9 @@ class CoQA(Task):
         additional_answers = doc.get("additional_answers")
         if additional_answers:
             for key in additional_answers:
-                additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1]
+                additional_answer_for_turn = additional_answers[key]["input_text"][
+                    turn_id - 1
+                ]
                 if additional_answer_for_turn.lower() not in map(str.lower, answers):
                     answers.append(additional_answer_for_turn)
         return answers

@@ -89,12 +93,12 @@ class CoQA(Task):
         # ~ 2/3 of the CoQA answers are span-based
         # (answers overlap with the passage ignoring punctuation and case mismatch)
         if raw_text == "unknown":
-            return '0'
+            return "0"
         if squad_metrics.normalize_answer(raw_text) == "yes":
-            return '1'
+            return "1"
         if squad_metrics.normalize_answer(raw_text) == "no":
-            return '2'
-        return '3'  # Not a yes/no question
+            return "2"
+        return "3"  # Not a yes/no question

@@ -104,25 +108,30 @@ class CoQA(Task):
         em_sum = 0.0
         if len(gold_list) > 1:
             for i in range(len(gold_list)):
-                gold_answers = gold_list[0:i] + gold_list[i + 1:]
+                gold_answers = gold_list[0:i] + gold_list[i + 1 :]
                 # predictions compared against (n) golds and take maximum
-                em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
+                em_sum += max(
+                    squad_metrics.compute_exact(a, pred) for a in gold_answers
+                )
                 f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
         else:
             em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
             f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)
-        return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))}
+        return {
+            "em": em_sum / max(1, len(gold_list)),
+            "f1": f1_sum / max(1, len(gold_list)),
+        }

     def doc_to_target(self, doc, turnid=None):
         # Default to prediction of last turn.
         if turnid is None:
             turnid = len(doc["questions"]["input_text"])
-        raw_text = doc['answers']["input_text"][turnid - 1]
+        raw_text = doc["answers"]["input_text"][turnid - 1]
         return " " + raw_text

     def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
+        """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.

@@ -132,7 +141,7 @@ class CoQA(Task):
-        cont_request = rf.greedy_until(ctx, ['\nQ:'])
+        cont_request = rf.greedy_until(ctx, ["\nQ:"])
         return cont_request

@@ -147,13 +156,13 @@ class CoQA(Task):
         turn_id = len(doc["questions"]["input_text"])
         gold_list = self.get_answers(doc, turn_id)
-        pred = results[0].strip().split('\n')[0]
+        pred = results[0].strip().split("\n")[0]
         scores = self.compute_scores(gold_list, pred)
         return {
-            "f1": scores['f1'],
-            "em": scores['em'],
+            "f1": scores["f1"],
+            "em": scores["em"],
         }
lm_eval/tasks/drop.py (view file @ 121b7096)

@@ -70,21 +70,26 @@ class DROP(Task):
     @classmethod
     def get_answers(cls, qa):
         def _flatten_validated_answers(validated_answers):
-            """ Flattens a dict of lists of validated answers.
+            """Flattens a dict of lists of validated answers.
             {"number": ['1', '8'], ...}
             -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
             """
             vas = []
             for i in range(len(validated_answers["number"])):
-                vas.append({
-                    "number": validated_answers["number"][i],
-                    "date": validated_answers["date"][i],
-                    "spans": validated_answers["spans"][i],
-                })
+                vas.append(
+                    {
+                        "number": validated_answers["number"][i],
+                        "date": validated_answers["date"][i],
+                        "spans": validated_answers["spans"][i],
+                    }
+                )
             return vas
         answers = []
         answers_set = set()
-        candidates = [qa["answer"]] + _flatten_validated_answers(qa["validated_answers"])
+        candidates = [qa["answer"]] + _flatten_validated_answers(
+            qa["validated_answers"]
+        )
         for candidate in candidates:
             answer = cls.parse_answer(candidate)
             if answer in answers_set:

@@ -100,9 +105,11 @@ class DROP(Task):
             return (str(answer["number"]),)
         if answer["spans"] != []:
             return tuple(answer["spans"])
-        return (" ".join([answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]).strip(),)
+        return (
+            " ".join(
+                [answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
+            ).strip(),
+        )

     def doc_to_text(self, doc):
         return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"

@@ -111,7 +118,7 @@ class DROP(Task):
     def doc_to_decontamination_query(self, doc):
-        return doc['passage'] + " " + doc['question']
+        return doc["passage"] + " " + doc["question"]

@@ -148,10 +155,7 @@ class DROP(Task):
             if gold_answer[0].strip():
                 max_em = max(max_em, exact_match)
                 max_f1 = max(max_f1, f1_score)
-        return {
-            "em": max_em,
-            "f1": max_f1
-        }
+        return {"em": max_em, "f1": max_f1}

@@ -164,7 +168,9 @@ class DROP(Task):
         predicted_bags = self._answer_to_bags(predicted)
         gold_bags = self._answer_to_bags(gold)
-        if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]):
+        if set(predicted_bags[0]) == set(gold_bags[0]) and len(
+            predicted_bags[0]
+        ) == len(gold_bags[0]):
             exact_match = 1.0
         else:
             exact_match = 0.0

@@ -196,7 +202,9 @@ class DROP(Task):
         for gold_index, gold_item in enumerate(gold):
             for pred_index, pred_item in enumerate(predicted):
                 if self._match_numbers_if_present(gold_item, pred_item):
-                    scores[gold_index, pred_index] = self._compute_f1(pred_item, gold_item)
+                    scores[gold_index, pred_index] = self._compute_f1(
+                        pred_item, gold_item
+                    )
         row_ind, col_ind = linear_sum_assignment(-scores)
         max_scores = np.zeros([max(len(gold), len(predicted))])

@@ -262,7 +270,11 @@ class DROP(Task):
     def _normalize(self, answer):
         tokens = [
-            self._white_space_fix(self._remove_articles(self._fix_number(self._remove_punc(token.lower()))))
+            self._white_space_fix(
+                self._remove_articles(self._fix_number(self._remove_punc(token.lower())))
+            )
             for token in self._tokenize(answer)
         ]
         tokens = [token for token in tokens if token.strip()]

@@ -275,10 +287,7 @@ class DROP(Task):
-        return {
-            "em": mean,
-            "f1": mean
-        }
+        return {"em": mean, "f1": mean}

@@ -286,7 +295,4 @@ class DROP(Task):
-        return {
-            "em": True,
-            "f1": True
-        }
+        return {"em": True, "f1": True}