gaoqiong / lm-evaluation-harness · Commits

Commit 121b7096
authored May 02, 2022 by Fabrizio Milo
parent 7a038118

    add pre-commit

Showing 20 changed files with 544 additions and 396 deletions (+544 / -396)
lm_eval/datasets/truthfulqa/truthfulqa.py       +36  -26
lm_eval/datasets/unscramble/dataset_infos.json   +1   -1
lm_eval/datasets/unscramble/unscramble.py        +3   -2
lm_eval/datasets/wikitext/dataset_infos.json     +1   -1
lm_eval/datasets/wikitext/wikitext.py           +55  -30
lm_eval/decontamination/archiver.py             +47  -33
lm_eval/decontamination/decontaminate.py        +32  -17
lm_eval/decontamination/janitor.py              +45  -33
lm_eval/evaluator.py                            +80  -44
lm_eval/metrics.py                              +15  -10
lm_eval/models/gpt2.py                          +42  -19
lm_eval/models/gpt3.py                          +36  -21
lm_eval/tasks/__init__.py                       +22  -35
lm_eval/tasks/anli.py                           +23  -24
lm_eval/tasks/arithmetic.py                      +4   -8
lm_eval/tasks/asdiv.py                          +12  -19
lm_eval/tasks/blimp.py                          +10   -4
lm_eval/tasks/cbt.py                             +8  -12
lm_eval/tasks/coqa.py                           +37  -28
lm_eval/tasks/drop.py                           +35  -29
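Every hunk below is formatter output rather than a behavioral change: single quotes become double quotes, trailing commas are added, and long calls are rewrapped, all of which matches Black's style. The repository's actual .pre-commit-config.yaml is not shown in this diff view; a minimal sketch of a config that would drive such a sweep might look like this (the pinned rev and hook selection are assumptions, not taken from the commit):

    # Hypothetical .pre-commit-config.yaml; the real file is not visible here.
    repos:
      - repo: https://github.com/psf/black
        rev: 22.3.0  # assumed pin; the actual revision is not shown in this diff
        hooks:
          - id: black

After `pre-commit install`, running `pre-commit run --all-files` once is what typically produces a repository-wide reformatting commit like this one.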
lm_eval/datasets/truthfulqa/truthfulqa.py  View file @ 121b7096

...
@@ -65,13 +65,14 @@ class TruthfulqaConfig(datasets.BuilderConfig):
 class Truthfulqa(datasets.GeneratorBasedBuilder):
     """TruthfulQA is a benchmark to measure whether a language model is truthful in
 generating answers to questions."""

     BUILDER_CONFIGS = [
         TruthfulqaConfig(
             name="multiple_choice",
             url="https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/data/mc_task.json",
-            features=datasets.Features({
+            features=datasets.Features(
+                {
                     "question": datasets.Value("string"),
                     "mc1_targets": {
                         "choices": datasets.features.Sequence(datasets.Value("string")),
...
@@ -80,23 +81,30 @@ generating answers to questions."""
                     "mc2_targets": {
                         "choices": datasets.features.Sequence(datasets.Value("string")),
                         "labels": datasets.features.Sequence(datasets.Value("int32")),
-                    }
-            }),
+                    },
+                }
+            ),
             description="The multiple choice TruthfulQA task",
         ),
         TruthfulqaConfig(
             name="generation",
             url="https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/TruthfulQA.csv",
-            features=datasets.Features({
+            features=datasets.Features(
+                {
                     "category": datasets.Value("string"),
                     "question": datasets.Value("string"),
                     "best_answer": datasets.Value("string"),
-                    "correct_answers": datasets.features.Sequence(datasets.Value("string")),
-                    "incorrect_answers": datasets.features.Sequence(datasets.Value("string")),
+                    "correct_answers": datasets.features.Sequence(
+                        datasets.Value("string")
+                    ),
+                    "incorrect_answers": datasets.features.Sequence(
+                        datasets.Value("string")
+                    ),
                     "source": datasets.Value("string"),
-            }),
-            description="The generative TruthfulQA task"
+                }
+            ),
+            description="The generative TruthfulQA task",
         ),
     ]

     def _info(self):
...
@@ -138,15 +146,15 @@ generating answers to questions."""
                         "mc2_targets": {
                             "choices": row["mc2_targets"].keys(),
                             "labels": row["mc2_targets"].values(),
-                        }
+                        },
                     }
         else:
             # Generation data is in a `CSV` file.
-            with open(filepath, newline='') as f:
+            with open(filepath, newline="") as f:
                 contents = csv.DictReader(f)
                 for key, row in enumerate(contents):
                     # Ensure that references exist.
-                    if not row['Correct Answers'] or not row['Incorrect Answers']:
+                    if not row["Correct Answers"] or not row["Incorrect Answers"]:
                         continue
                     yield key, {
                         "category": row["Category"],
...
@@ -154,6 +162,8 @@ generating answers to questions."""
                         "best_answer": row["Best Answer"],
                         # split on ";"
                         "correct_answers": row["Correct Answers"].strip().split(";"),
-                        "incorrect_answers": row["Incorrect Answers"].strip().split(";"),
+                        "incorrect_answers": row["Incorrect Answers"]
+                        .strip()
+                        .split(";"),
                         "source": row["Source"],
                     }
lm_eval/datasets/unscramble/dataset_infos.json  View file @ 121b7096
lm_eval/datasets/unscramble/unscramble.py  View file @ 121b7096

...
@@ -64,8 +64,9 @@ class Unscramble(datasets.GeneratorBasedBuilder):
     VERSION = datasets.Version("0.0.1")
     BUILDER_CONFIGS = [
-        datasets.BuilderConfig(name=name, version=version,
-                               description=_DESCRIPTIONS[name])
+        datasets.BuilderConfig(
+            name=name, version=version, description=_DESCRIPTIONS[name]
+        )
         for name, version in zip(_NAMES, [VERSION] * len(_NAMES))
     ]
...
lm_eval/datasets/wikitext/dataset_infos.json  View file @ 121b7096
lm_eval/datasets/wikitext/wikitext.py  View file @ 121b7096

...
@@ -123,86 +123,111 @@ class Wikitext(datasets.GeneratorBasedBuilder):
             return [
                 datasets.SplitGenerator(
                     name=datasets.Split.TEST,
-                    gen_kwargs={"data_file": os.path.join(
-                        data_dir, "wiki.test.tokens"), "split": "test"},
+                    gen_kwargs={
+                        "data_file": os.path.join(data_dir, "wiki.test.tokens"),
+                        "split": "test",
+                    },
                 ),
                 datasets.SplitGenerator(
                     name=datasets.Split.TRAIN,
-                    gen_kwargs={"data_file": os.path.join(
-                        data_dir, "wiki.train.tokens"), "split": "train"},
+                    gen_kwargs={
+                        "data_file": os.path.join(data_dir, "wiki.train.tokens"),
+                        "split": "train",
+                    },
                 ),
                 datasets.SplitGenerator(
                     name=datasets.Split.VALIDATION,
-                    gen_kwargs={"data_file": os.path.join(
-                        data_dir, "wiki.valid.tokens"), "split": "valid"},
+                    gen_kwargs={
+                        "data_file": os.path.join(data_dir, "wiki.valid.tokens"),
+                        "split": "valid",
+                    },
                 ),
             ]
         else:
             if self.config.name == "wikitext-103-raw-v1":
                 data_file = dl_manager.download_and_extract(self.config.data_url)
                 data_dir = os.path.join(data_file, "wikitext-103-raw")
                 return [
                     datasets.SplitGenerator(
                         name=datasets.Split.TEST,
-                        gen_kwargs={"data_file": os.path.join(
-                            data_dir, "wiki.test.raw"), "split": "test"},
+                        gen_kwargs={
+                            "data_file": os.path.join(data_dir, "wiki.test.raw"),
+                            "split": "test",
+                        },
                     ),
                     datasets.SplitGenerator(
                         name=datasets.Split.TRAIN,
-                        gen_kwargs={"data_file": os.path.join(
-                            data_dir, "wiki.train.raw"), "split": "train"},
+                        gen_kwargs={
+                            "data_file": os.path.join(data_dir, "wiki.train.raw"),
+                            "split": "train",
+                        },
                     ),
                     datasets.SplitGenerator(
                         name=datasets.Split.VALIDATION,
-                        gen_kwargs={"data_file": os.path.join(
-                            data_dir, "wiki.valid.raw"), "split": "valid"},
+                        gen_kwargs={
+                            "data_file": os.path.join(data_dir, "wiki.valid.raw"),
+                            "split": "valid",
+                        },
                     ),
                 ]
             else:
                 if self.config.name == "wikitext-2-raw-v1":
                     data_file = dl_manager.download_and_extract(self.config.data_url)
                     data_dir = os.path.join(data_file, "wikitext-2-raw")
                     return [
                         datasets.SplitGenerator(
                             name=datasets.Split.TEST,
-                            gen_kwargs={"data_file": os.path.join(
-                                data_dir, "wiki.test.raw"), "split": "test"},
+                            gen_kwargs={
+                                "data_file": os.path.join(data_dir, "wiki.test.raw"),
+                                "split": "test",
+                            },
                         ),
                         datasets.SplitGenerator(
                             name=datasets.Split.TRAIN,
-                            gen_kwargs={"data_file": os.path.join(
-                                data_dir, "wiki.train.raw"), "split": "train"},
+                            gen_kwargs={
+                                "data_file": os.path.join(data_dir, "wiki.train.raw"),
+                                "split": "train",
+                            },
                         ),
                         datasets.SplitGenerator(
                             name=datasets.Split.VALIDATION,
-                            gen_kwargs={"data_file": os.path.join(
-                                data_dir, "wiki.valid.raw"), "split": "valid"},
+                            gen_kwargs={
+                                "data_file": os.path.join(data_dir, "wiki.valid.raw"),
+                                "split": "valid",
+                            },
                         ),
                     ]
                 else:
                     if self.config.name == "wikitext-2-v1":
                         data_file = dl_manager.download_and_extract(self.config.data_url)
                         data_dir = os.path.join(data_file, "wikitext-2")
                         return [
                             datasets.SplitGenerator(
                                 name=datasets.Split.TEST,
-                                gen_kwargs={"data_file": os.path.join(
-                                    data_dir, "wiki.test.tokens"), "split": "test"},
+                                gen_kwargs={
+                                    "data_file": os.path.join(data_dir, "wiki.test.tokens"),
+                                    "split": "test",
+                                },
                             ),
                             datasets.SplitGenerator(
                                 name=datasets.Split.TRAIN,
                                 gen_kwargs={
                                     "data_file": os.path.join(data_dir, "wiki.train.tokens"),
                                     "split": "train",
                                 },
                             ),
                             datasets.SplitGenerator(
                                 name=datasets.Split.VALIDATION,
                                 gen_kwargs={
                                     "data_file": os.path.join(data_dir, "wiki.valid.tokens"),
                                     "split": "valid",
                                 },
                             ),
...
@@ -216,12 +241,12 @@ class Wikitext(datasets.GeneratorBasedBuilder):
                 data = f.read().split("\n")
                 for line in data:
                     rline = line.replace("= = =", "===").replace("= =", "==").strip()
-                    if rline.startswith('=') and rline.strip().endswith('='):
-                        page = '\n'.join(ret)
+                    if rline.startswith("=") and rline.strip().endswith("="):
+                        page = "\n".join(ret)
                         if page.strip():
                             yield key, {"page": page}
                             key += 1
                         ret = []
                     ret.append(line)
-                page = '\n'.join(ret)
+                page = "\n".join(ret)
                 yield key, {"page": page}
lm_eval/decontamination/archiver.py  View file @ 121b7096

...
@@ -8,12 +8,14 @@ import mmap
 import tqdm
 from pathlib import Path

+
 def json_serial(obj):
     """JSON serializer for objects not serializable by default json code"""

     if isinstance(obj, (datetime.datetime,)):
         return obj.isoformat()
     raise TypeError("Type %s not serializable" % type(obj))

+
 # Modified version of lm_dataformat Archive for single file.
 class Archive:
...
@@ -22,25 +24,31 @@ class Archive:
         dir_name = os.path.dirname(file_path)
         if dir_name:
             os.makedirs(dir_name, exist_ok=True)
-        self.fh = open(self.file_path, 'wb')
+        self.fh = open(self.file_path, "wb")
         self.cctx = zstandard.ZstdCompressor(level=compression_level)
         self.compressor = self.cctx.stream_writer(self.fh)

     def add_data(self, data, meta={}):
-        self.compressor.write(json.dumps({'text': data, 'meta': meta}, default=json_serial).encode('UTF-8') + b'\n')
+        self.compressor.write(
+            json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
+                "UTF-8"
+            )
+            + b"\n"
+        )

     def commit(self):
         self.compressor.flush(zstandard.FLUSH_FRAME)
         self.fh.flush()
         self.fh.close()

+
 # Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
 class Reader:
     def __init__(self):
         pass

-    def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n'):
-        with open(file, 'rb') as fh:
+    def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner="\n\n"):
+        with open(file, "rb") as fh:
             self.fh = fh
             cctx = zstandard.ZstdDecompressor()
             reader = io.BufferedReader(cctx.stream_reader(fh))
...
@@ -52,16 +60,17 @@ class Reader:
                     yield ob
                     continue

-                text = ob['text']
+                text = ob["text"]

                 if autojoin_paragraphs and isinstance(text, list):
                     text = para_joiner.join(text)

                 if get_meta:
-                    yield text, (ob['meta'] if 'meta' in ob else {})
+                    yield text, (ob["meta"] if "meta" in ob else {})
                 else:
                     yield text

+
 class TextArchive:
     def __init__(self, file_path, mode="rb+"):
         self.file_path = file_path
...
@@ -75,12 +84,13 @@ class TextArchive:
         self.fh = open(self.file_path, mode)

     def add_data(self, data):
-        self.fh.write(data.encode('UTF-8') + b'\n')
+        self.fh.write(data.encode("UTF-8") + b"\n")

     def commit(self):
         self.fh.flush()
         self.fh.close()

+
 class TextReader:
     def __init__(self, file_path):
         self.file_path = file_path
...
@@ -90,9 +100,12 @@ class TextReader:
     def read_tqdm(self, update_frequency=10000):
         current_file_position = 0
         line_counter = 0
-        with open(self.file_path, 'r') as fh, \
-            tqdm.tqdm(total=os.path.getsize(self.file_path), dynamic_ncols=True,
-                      unit="byte", unit_scale=1) as progress:
+        with open(self.file_path, "r") as fh, tqdm.tqdm(
+            total=os.path.getsize(self.file_path),
+            dynamic_ncols=True,
+            unit="byte",
+            unit_scale=1,
+        ) as progress:
             with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                 for line in iter(mmap_obj.readline, b""):
                     line = line.decode("utf-8")
...
@@ -107,7 +120,7 @@ class TextReader:
     def read_and_tell(self):
         current_file_position = 0
-        with open(self.file_path, 'r', encoding="utf8") as fh:
+        with open(self.file_path, "r", encoding="utf8") as fh:
             with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                 for line in iter(mmap_obj.readline, b""):
                     line = line.decode("utf-8")
...
@@ -117,14 +130,14 @@ class TextReader:
                     yield line[:-1], raw_bytes_read

     def read(self):
-        with open(self.file_path, 'r', encoding="utf8") as fh:
+        with open(self.file_path, "r", encoding="utf8") as fh:
             with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                 for line in iter(mmap_obj.readline, b""):
                     line = line.decode("utf-8")
                     yield line[:-1]

     def read_slow(self):
-        with open(self.file_path, 'r', encoding="utf8") as fh:
+        with open(self.file_path, "r", encoding="utf8") as fh:
             while True:
                 line = fh.readline()
                 if line == -1 or line == "":
...
@@ -132,6 +145,7 @@ class TextReader:
                 else:
                     yield line[:-1]

+
 # Optimized for speed. Decompresses the archive in shell before
 # using the mmap'd TextReader.
 class ZStdTextReader:
...
lm_eval/decontamination/decontaminate.py  View file @ 121b7096

...
@@ -57,19 +57,27 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
             os.mkdir(f"data/{task_name}")

         # Check if we've decontaminated this combination before
         overlaps_dump_path = get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit)
         if os.path.exists(overlaps_dump_path):
-            duplicates[(task_name, task_set)] = pickle.load(open(overlaps_dump_path, "rb"))
+            duplicates[(task_name, task_set)] = pickle.load(
+                open(overlaps_dump_path, "rb")
+            )
             sets_to_decontaminate -= 1
             continue
         else:
             duplicates[(task_name, task_set)] = set()

         # Build/load the task lookup {ngram: set(documents)}.
-        task_set_lookup_path = f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
+        task_set_lookup_path = (
+            f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
+        )
         if os.path.exists(task_set_lookup_path):
             print(f"{task_set_lookup_path} available, loading...")
-            lookups[(task_name, task_set)] = pickle.load(open(task_set_lookup_path, "rb"))
+            lookups[(task_name, task_set)] = pickle.load(
+                open(task_set_lookup_path, "rb")
+            )
         else:
             print(f"{task_set_lookup_path} not available, building...")
             lookup = collections.defaultdict(set)
...
@@ -79,7 +87,7 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
                 for ngram in ngrams:
                     lookup[ngram].add(doc_id)

             pickle.dump(lookup, open(task_set_lookup_path, "wb"))
             lookups[(task_name, task_set)] = lookup

         elapsed = time.perf_counter() - start
...
@@ -115,7 +123,9 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
         for line in reader.read_tqdm():  # Scan training set ngrams file
             total_ngrams += 1
             [ngram, document_id] = line.rsplit(" ", 1)
-            if ngram != current_ngram:  # Only need to match the ngram once in training set
+            if (
+                ngram != current_ngram
+            ):  # Only need to match the ngram once in training set
                 unique_ngrams += 1
                 current_ngram = ngram
                 if ngram in merged_lookup:
...
@@ -123,7 +133,11 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
                     matching_unique += 1
                     for task_name, task_set, doc_ids in merged_lookup[ngram]:
                         task_doc_set = duplicates[(task_name, task_set)]
-                        for doc_id in doc_ids:  # Record contamination across all relevant task/set combos
+                        for (
+                            doc_id
+                        ) in doc_ids:  # Record contamination across all relevant task/set combos
                             task_doc_set.add(doc_id)
                     del merged_lookup[ngram]  # No point matching again
                 else:
...
@@ -145,9 +159,10 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
     # Dump overlaps separately
     for (task_name, task_set), doc_ids in duplicates.items():
-        overlaps_dump_path = get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit)
+        overlaps_dump_path = get_overlaps_dump_path(
+            task_name, task_set, ngrams_n_size, limit
+        )
         pickle.dump(doc_ids, open(overlaps_dump_path, "wb"))

     # Strip task set and return
     return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
lm_eval/decontamination/janitor.py  View file @ 121b7096

...
@@ -9,6 +9,7 @@ from pprint import pprint
 # c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
 try:
     import janitor_util
+
     JANITOR_CPP = True
 except Exception as e:
     print("WARNING: C++ module could not be loaded. Janitor running in python mode")
...
@@ -41,6 +42,7 @@ def word_ngrams(s, n):
     ngram_seqs = form_ngrams(iter(tokens), n)
     return (" ".join(ngram) for ngram in ngram_seqs)

+
 # Does character sequences only - combined faster function to play around with later
 # def word_ngrams_indices_combined(sequence, n):
 #     current_word = ""
...
@@ -70,7 +72,7 @@ def split_indices(s):
     """Splits a string on whitespaces and records the indices of each in the original string.
     @:return generator((word, (start_idx, end_idx)), ...)
     """
-    return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r'\S+', s))
+    return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))


 def word_ngrams_indices(s, n):
...
@@ -90,10 +92,15 @@ def word_ngrams_indices(s, n):
     #   ([word, word, ...], [(start,end), (start,end), ...]),
     #   ...
     # )
-    ngram_indices_pairs = (zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices)
+    ngram_indices_pairs = (
+        zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices
+    )

     # Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
-    return ((" ".join(ngram_seq), (indices[0][0], indices[-1][1])) for ngram_seq, indices in ngram_indices_pairs)
+    return (
+        (" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
+        for ngram_seq, indices in ngram_indices_pairs
+    )


 class Janitor:
...
@@ -105,7 +112,7 @@ class Janitor:
         window_to_remove=200,
         too_dirty_cutoff=10,
         minimum_slice_length=200,
-        delete_chars=string.punctuation
+        delete_chars=string.punctuation,
     ):
         self.ngram_n = ngram_n
         self.window_to_remove = window_to_remove
...
@@ -121,7 +128,7 @@ class Janitor:
         self.translation_table = str.maketrans(
             string.ascii_lowercase + string.ascii_uppercase,  # These characters
             string.ascii_lowercase * 2,  # Become these characters
-            self.delete_chars  # These are deleted
+            self.delete_chars,  # These are deleted
         )

     ##############
...
@@ -129,14 +136,13 @@ class Janitor:
     ##############

     def save_contamination_ngrams(self, filename):
-        with open(filename, 'wb') as fp:
+        with open(filename, "wb") as fp:
             pickle.dump(filename, fp)

     def load_contamination_ngrams(self, filename):
-        with open(filename, 'rb') as fp:
+        with open(filename, "rb") as fp:
             self.dirt_ngrams = pickle.load(fp)

     ##############
     # Call these :)
     ##############
...
@@ -171,11 +177,11 @@ class Janitor:
             end = min(len(dirty_string), end + self.window_to_remove)

             if start - splice_idx > self.minimum_slice_length:
                 clean_chunks.append(dirty_string[splice_idx:start])
             splice_idx = end

         if end < len(dirty_string) - self.minimum_slice_length:
-            clean_chunks.append(dirty_string[end + 1:])
+            clean_chunks.append(dirty_string[end + 1 :])

         return clean_chunks
...
@@ -184,10 +190,14 @@ class Janitor:
     ##############

     def register_contaminant_cpp(self, dirt_string):
-        self.dirt_ngrams.update(janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n))
+        self.dirt_ngrams.update(
+            janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
+        )

     def clean_cpp(self, dirty_string):
-        contamination_indices = janitor_util.clean_ngram_with_indices(dirty_string, self.delete_chars, self.ngram_n)
+        contamination_indices = janitor_util.clean_ngram_with_indices(
+            dirty_string, self.delete_chars, self.ngram_n
+        )
         return self._split_chunks(dirty_string, contamination_indices)

     ##############
...
@@ -198,7 +208,9 @@ class Janitor:
         return s.translate(self.translation_table)

     def register_contaminant_python(self, dirt_string):
-        self.dirt_ngrams.update(word_ngrams(self.normalize_string(dirt_string), self.ngram_n))
+        self.dirt_ngrams.update(
+            word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
+        )

     def clean_python(self, dirty_string):
         contamination_indices = (
...
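The str.maketrans call reformatted in the hunk above builds a single translation table that both case-folds and strips punctuation. A minimal standalone illustration of the same three-argument idiom (the variable names here are illustrative, not from the diff):

    import string

    # Map upper- and lowercase ASCII onto lowercase and delete punctuation,
    # mirroring how Janitor.translation_table is constructed above.
    table = str.maketrans(
        string.ascii_lowercase + string.ascii_uppercase,  # map these characters...
        string.ascii_lowercase * 2,                       # ...onto these
        string.punctuation,                               # ...and delete these
    )

    print("Hello, World!".translate(table))  # -> "hello world"

Folding all normalization into one translate() pass is what lets register_contaminant_python stay a single generator pipeline over the normalized string.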
lm_eval/evaluator.py  View file @ 121b7096

...
@@ -11,12 +11,22 @@ import numpy as np
 from lm_eval.utils import positional_deprecated, run_task_tests
 from lm_eval.decontamination.decontaminate import get_train_overlap


 @positional_deprecated
-def simple_evaluate(model, model_args=None, tasks=[],
-                    num_fewshot=0, batch_size=None, device=None,
-                    no_cache=False, limit=None, bootstrap_iters=100000,
-                    description_dict=None, check_integrity=False,
-                    decontamination_ngrams_path=None):
+def simple_evaluate(
+    model,
+    model_args=None,
+    tasks=[],
+    num_fewshot=0,
+    batch_size=None,
+    device=None,
+    no_cache=False,
+    limit=None,
+    bootstrap_iters=100000,
+    description_dict=None,
+    check_integrity=False,
+    decontamination_ngrams_path=None,
+):

     """Instantiate and evaluate a model on a list of tasks.
...
@@ -52,17 +62,23 @@ def simple_evaluate(model, model_args=None, tasks=[],
     assert tasks != [], "No tasks specified"

     if isinstance(model, str):
         if model_args is None:
             model_args = ""
-        lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
-            'batch_size': batch_size, 'device': device
-        })
+        lm = lm_eval.models.get_model(model).create_from_arg_string(
+            model_args, {"batch_size": batch_size, "device": device}
+        )
     else:
         assert isinstance(model, lm_eval.base.LM)
         lm = model

     if not no_cache:
         lm = lm_eval.base.CachingLM(
-            lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db'
+            lm,
+            "lm_cache/"
+            + model
+            + "_"
+            + model_args.replace("=", "-").replace(",", "_").replace("/", "-")
+            + ".db",
         )

     task_dict = lm_eval.tasks.get_task_dict(tasks)
...
@@ -89,16 +105,26 @@ def simple_evaluate(model, model_args=None, tasks=[],
         "no_cache": no_cache,
         "limit": limit,
         "bootstrap_iters": bootstrap_iters,
-        "description_dict": description_dict
+        "description_dict": description_dict,
     }

     return results


 decontaminate_suffix = "_decontaminate"


 @positional_deprecated
-def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
-             bootstrap_iters=100000, description_dict=None,
-             decontamination_ngrams_path=None):
+def evaluate(
+    lm,
+    task_dict,
+    provide_description=None,
+    num_fewshot=0,
+    limit=None,
+    bootstrap_iters=100000,
+    description_dict=None,
+    decontamination_ngrams_path=None,
+):
     """Instantiate and evaluate a model on a list of tasks.

     :param lm: obj
...
@@ -124,14 +150,16 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
     assert not provide_description  # not implemented.
     if provide_description is not None:
         # nudge people to not specify it at all
-        print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
+        print(
+            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
+        )

     decontaminate = decontamination_ngrams_path is not None

     task_dict_items = [
         (name, task)
         for name, task in task_dict.items()
         if (task.has_validation_docs() or task.has_test_docs())
     ]

     results = collections.defaultdict(dict)
...
@@ -172,19 +200,22 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         rnd.seed(42)
         rnd.shuffle(task_docs)

-        description = description_dict[task_name] if description_dict and task_name in description_dict else ""
+        description = (
+            description_dict[task_name]
+            if description_dict and task_name in description_dict
+            else ""
+        )

         for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
             if decontaminate and task.should_decontaminate():
-                docs_for_decontamination[(task_name, task_set)].append(task.doc_to_decontamination_query(doc))
+                docs_for_decontamination[(task_name, task_set)].append(
+                    task.doc_to_decontamination_query(doc)
+                )

             docs[(task_name, doc_id)] = doc
             ctx = task.fewshot_context(
-                doc=doc,
-                num_fewshot=num_fewshot, rnd=rnd, description=description
+                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
             )
             reqs = task.construct_requests(doc, ctx)
             if not isinstance(reqs, (list, tuple)):
...
@@ -198,7 +229,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
     # Compare all tasks/sets at once to ensure a single training set scan
     if decontaminate:
         print("Finding train/test overlap, please wait...")
-        overlaps = get_train_overlap(docs_for_decontamination, decontamination_ngrams_path, limit)
+        overlaps = get_train_overlap(
+            docs_for_decontamination, decontamination_ngrams_path, limit
+        )

     # all responses for each (task, doc)
     process_res_queue = collections.defaultdict(list)
...
@@ -212,7 +245,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         print("Running", reqtype, "requests")
         resps = getattr(lm, reqtype)([req.args for req in reqs])
-        resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]
+        resps = [
+            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
+        ]

         for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
             process_res_queue[(task_name, doc_id)].append((i, resp))
...
@@ -241,7 +276,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         task = task_dict[task_name]
         real_metric = metric  # key when looking up the metric with task.aggregation
         if metric.endswith(decontaminate_suffix):
-            real_metric = metric.replace(decontaminate_suffix, "")  # decontaminated still uses the same metric
+            real_metric = metric.replace(
+                decontaminate_suffix, ""
+            )  # decontaminated still uses the same metric
         results[task_name][metric] = task.aggregation()[real_metric](items)

         # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
...
@@ -249,16 +286,15 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         stderr = lm_eval.metrics.stderr_for_metric(
             metric=task.aggregation()[real_metric],
-            bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters,
+            bootstrap_iters=min(bootstrap_iters, 1000)
+            if metric in ["bleu", "chrf", "ter"]
+            else bootstrap_iters,
         )
         if stderr is not None:
             results[task_name][metric + "_stderr"] = stderr(items)

-    return {
-        "results": dict(results),
-        "versions": dict(versions)
-    }
+    return {"results": dict(results), "versions": dict(versions)}


 def make_table(result_dict):
...
@@ -280,9 +316,9 @@ def make_table(result_dict):
         if m + "_stderr" in dic:
             se = dic[m + "_stderr"]
-            values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
+            values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
         else:
-            values.append([k, version, m, '%.4f' % v, '', ''])
+            values.append([k, version, m, "%.4f" % v, "", ""])
         k = ""
         version = ""
     md_writer.value_matrix = values
...
lm_eval/metrics.py  View file @ 121b7096

...
@@ -103,6 +103,7 @@ def weighted_mean(items):
 def weighted_perplexity(items):
     return math.exp(-weighted_mean(items))

+
 def bits_per_byte(items):
     return -weighted_mean(items) / math.log(2)
...
@@ -184,8 +185,10 @@ def _sacreformat(refs, preds):
     return refs, preds

+
 # stderr stuff

+
 class _bootstrap_internal:
     def __init__(self, f, n):
         self.f = f
...
@@ -203,6 +206,7 @@ class _bootstrap_internal:
+
 def bootstrap_stderr(f, xs, iters):
     import multiprocessing as mp

     pool = mp.Pool(mp.cpu_count())
     # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
     # equivalent to stderr calculated without Bessel's correction in the stddev.
...
@@ -213,10 +217,15 @@ def bootstrap_stderr(f, xs, iters):
     res = []
     chunk_size = min(1000, iters)

     from tqdm import tqdm

     print("bootstrapping for stddev:", f.__name__)
-    for bootstrap in tqdm(pool.imap(
-            _bootstrap_internal(f, chunk_size),
-            [(i, xs) for i in range(iters // chunk_size)]), total=iters // chunk_size):
+    for bootstrap in tqdm(
+        pool.imap(
+            _bootstrap_internal(f, chunk_size),
+            [(i, xs) for i in range(iters // chunk_size)],
+        ),
+        total=iters // chunk_size,
+    ):
         # sample w replacement
         res.extend(bootstrap)
...
@@ -238,17 +247,13 @@ def stderr_for_metric(metric, bootstrap_iters):
     if metric in bootstrappable:
         return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)

-    stderr = {
-        mean: mean_stderr,
-        acc_all: acc_all_stderr
-    }
+    stderr = {mean: mean_stderr, acc_all: acc_all_stderr}

     return stderr.get(metric, None)


 def yesno(x):
     if x:
-        return 'yes'
+        return "yes"
     else:
-        return 'no'
+        return "no"
lm_eval/models/gpt2.py  View file @ 121b7096

...
@@ -4,8 +4,15 @@ from lm_eval.base import BaseLM
 class HFLM(BaseLM):
-    def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1):
+    def __init__(
+        self,
+        device="cuda",
+        pretrained="gpt2",
+        revision="main",
+        subfolder=None,
+        tokenizer=None,
+        batch_size=1,
+    ):
         super().__init__()

         assert isinstance(device, str)
...
@@ -20,28 +27,47 @@ class HFLM(BaseLM):
         else:
             print("Device not specificed")
             print(f"Cuda Available? {torch.cuda.is_available()}")
-            self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+            self._device = (
+                torch.device("cuda")
+                if torch.cuda.is_available()
+                else torch.device("cpu")
+            )

         # TODO: update this to be less of a hack once subfolder is fixed in HF
         self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
-            pretrained, revision=revision + ("/" + subfolder if subfolder is not None else "")
+            pretrained,
+            revision=revision + ("/" + subfolder if subfolder is not None else ""),
         ).to(self.device)
         self.gpt2.eval()

         # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
-            pretrained if tokenizer is None else tokenizer, revision=revision, subfolder=subfolder)
+            pretrained if tokenizer is None else tokenizer,
+            revision=revision,
+            subfolder=subfolder,
+        )

-        assert isinstance(self.tokenizer, (
-            transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
-            transformers.T5Tokenizer, transformers.T5TokenizerFast,
-        )), "this tokenizer has not been checked for compatibility yet!"
+        assert isinstance(
+            self.tokenizer,
+            (
+                transformers.GPT2Tokenizer,
+                transformers.GPT2TokenizerFast,
+                transformers.T5Tokenizer,
+                transformers.T5TokenizerFast,
+            ),
+        ), "this tokenizer has not been checked for compatibility yet!"

         self.vocab_size = self.tokenizer.vocab_size

-        if isinstance(self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)):
-            assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373], \
-                self.tokenizer.encode('hello\n\nhello')
+        if isinstance(
+            self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
+        ):
+            assert self.tokenizer.encode("hello\n\nhello") == [
+                31373,
+                198,
+                198,
+                31373,
+            ], self.tokenizer.encode("hello\n\nhello")

         # multithreading and batching
         self.batch_size_per_gpu = batch_size  # todo: adaptive batch size
...
@@ -97,10 +123,7 @@ class HFLM(BaseLM):
     def _model_generate(self, context, max_length, eos_token_id):
         return self.gpt2.generate(
-            context,
-            max_length=max_length,
-            eos_token_id=eos_token_id,
-            do_sample=False
+            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
         )
...
lm_eval/models/gpt3.py  View file @ 121b7096

...
@@ -36,17 +36,19 @@ def get_result(response, ctxlen):
 def oa_completion(**kwargs):
-    """
-    Query OpenAI API for completion.
+    """Query OpenAI API for completion.

     Retry with back-off until they respond
     """
     import openai

     backoff_time = 3
     while True:
         try:
             return openai.Completion.create(**kwargs)
         except openai.error.OpenAIError:
             import traceback

             traceback.print_exc()
             time.sleep(backoff_time)
             backoff_time *= 1.5
...
@@ -66,16 +68,19 @@ class GPT3LM(BaseLM):
         super().__init__()
         import openai

         self.engine = engine
-        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
+        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")

         self.vocab_size = self.tokenizer.vocab_size

         # to make the annoying "Using pad_token, but it is not set yet." error go away
         self.tokenizer.pad_token = "<|endoftext|>"
-        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
+        assert self.tokenizer.encode("hello\n\nhello") == [31373, 198, 198, 31373]
         self.truncate = truncate
-        self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])[0]
+        self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(
+            ["<|endoftext|>"]
+        )[0]

         # Read from environment variable OPENAI_API_SECRET_KEY
         openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
...
@@ -121,14 +126,19 @@ class GPT3LM(BaseLM):
         reord = utils.Reorderer(requests, _collate)

-        for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE)), disable=disable_tqdm):
+        for chunk in tqdm(
+            list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE)),
+            disable=disable_tqdm,
+        ):
             inps = []
             ctxlens = []
             for cache_key, context_enc, continuation_enc in chunk:
                 # max_length+1 because the API takes up to 2049 tokens, including the first context token
-                inp = (context_enc + continuation_enc)[-(self.max_length + 1):]
+                inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
                 # TODO: the logic is much simpler if we just look at the length of continuation tokens
-                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - (self.max_length + 1))
+                ctxlen = len(context_enc) - max(
+                    0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
+                )

                 inps.append(inp)
                 ctxlens.append(ctxlen)
...
@@ -137,11 +147,14 @@ class GPT3LM(BaseLM):
                 engine=self.engine,
                 prompt=inps,
                 echo=True,
-                max_tokens=0, temperature=0.,
+                max_tokens=0,
+                temperature=0.0,
                 logprobs=10,
             )

-            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(response.choices, ctxlens, chunk):
+            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
+                response.choices, ctxlens, chunk
+            ):
                 answer = get_result(resp, ctxlen)
                 res.append(answer)
...
@@ -177,24 +190,26 @@ class GPT3LM(BaseLM):
                 yield ret, lastuntil

         # todo: more intelligent batching for heterogeneous `until`
-        for chunk, until in tqdm(list(sameuntil_chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
+        for chunk, until in tqdm(
+            list(sameuntil_chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))
+        ):
             inps = []
             for context, _ in chunk:
                 context_enc = self.tok_encode(context)
-                inp = context_enc[-(self.max_length - self.max_gen_toks):]
+                inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                 inps.append(inp)

             response = oa_completion(
                 engine=self.engine,
                 prompt=inps,
                 max_tokens=self.max_gen_toks,
-                temperature=0.,
+                temperature=0.0,
                 logprobs=10,
                 stop=until,
             )

             for resp, (context, until_) in zip(response.choices, chunk):
-                s = resp['text']
+                s = resp["text"]

                 for term in until_:
                     s = s.split(term)[0]
...
lm_eval/tasks/__init__.py  View file @ 121b7096

...
@@ -59,8 +59,8 @@ from . import storycloze
 # 6 total
 gpt3_translation_benchmarks = {
-    "wmt14": ['en-fr', 'fr-en'],  # French
-    "wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'],  # German, Romanian
+    "wmt14": ["en-fr", "fr-en"],  # French
+    "wmt16": ["en-ro", "ro-en", "de-en", "en-de"],  # German, Romanian
 }
...
@@ -68,7 +68,7 @@ gpt3_translation_benchmarks = {
 selected_translation_benchmarks = {
     **gpt3_translation_benchmarks,
     "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
-    "iwslt17": ['en-ar', 'ar-en']  # Arabic
+    "iwslt17": ["en-ar", "ar-en"],  # Arabic
 }

 # 319 total
...
@@ -92,7 +92,7 @@ TASK_REGISTRY = {
     "rte": glue.RTE,
     "qnli": glue.QNLI,
     "qqp": glue.QQP,
-    #"stsb": glue.STSB, # not implemented yet
+    # "stsb": glue.STSB, # not implemented yet
     "sst": glue.SST,
     "wnli": glue.WNLI,
     # SuperGLUE
...
@@ -103,34 +103,26 @@ TASK_REGISTRY = {
     "record": superglue.ReCoRD,
     "wic": superglue.WordsInContext,
     "wsc": superglue.SGWinogradSchemaChallenge,
     # Order by benchmark/genre?
     "coqa": coqa.CoQA,
     "drop": drop.DROP,
     "lambada": lambada.LAMBADA,
     "lambada_cloze": lambada_cloze.LAMBADA_cloze,
     # multilingual lambada
     **lambada_multilingual.construct_tasks(),
     "wikitext": wikitext.WikiText,
     # "cbt-cn": cbt.CBTCN, # disabled pending context length fix
     # "cbt-ne": cbt.CBTNE, # disabled pending context length fix
     "piqa": piqa.PiQA,
     "prost": prost.PROST,
     "mc_taco": mc_taco.MCTACO,
     # Science related
     "pubmedqa": pubmedqa.Pubmed_QA,
     "sciq": sciq.SciQ,
     "qasper": qasper.QASPER,
     "qa4mre_2011": qa4mre.QA4MRE_2011,
     "qa4mre_2012": qa4mre.QA4MRE_2012,
     "qa4mre_2013": qa4mre.QA4MRE_2013,
     "triviaqa": triviaqa.TriviaQA,
     "arc_easy": arc.ARCEasy,
     "arc_challenge": arc.ARCChallenge,
...
@@ -152,21 +144,17 @@ TASK_REGISTRY = {
     "anli_r1": anli.ANLIRound1,
     "anli_r2": anli.ANLIRound2,
     "anli_r3": anli.ANLIRound3,
     "ethics_cm": hendrycks_ethics.EthicsCM,
     "ethics_deontology": hendrycks_ethics.EthicsDeontology,
     "ethics_justice": hendrycks_ethics.EthicsJustice,
     "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
     "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
     "ethics_virtue": hendrycks_ethics.EthicsVirtue,
     "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice
,
"truthfulqa_mc"
:
truthfulqa
.
TruthfulQAMultipleChoice
,
"truthfulqa_gen"
:
truthfulqa
.
TruthfulQAGeneration
,
"truthfulqa_gen"
:
truthfulqa
.
TruthfulQAGeneration
,
# dialogue
# dialogue
"mutual"
:
mutual
.
MuTual
,
"mutual"
:
mutual
.
MuTual
,
"mutual_plus"
:
mutual
.
MuTualPlus
,
"mutual_plus"
:
mutual
.
MuTualPlus
,
# math
# math
"math_algebra"
:
hendrycks_math
.
MathAlgebra
,
"math_algebra"
:
hendrycks_math
.
MathAlgebra
,
"math_counting_and_prob"
:
hendrycks_math
.
MathCountingAndProbability
,
"math_counting_and_prob"
:
hendrycks_math
.
MathCountingAndProbability
,
...
@@ -177,7 +165,6 @@ TASK_REGISTRY = {
...
@@ -177,7 +165,6 @@ TASK_REGISTRY = {
"math_precalc"
:
hendrycks_math
.
MathPrecalculus
,
"math_precalc"
:
hendrycks_math
.
MathPrecalculus
,
"math_asdiv"
:
asdiv
.
Asdiv
,
"math_asdiv"
:
asdiv
.
Asdiv
,
"gsm8k"
:
gsm8k
.
GradeSchoolMath8K
,
"gsm8k"
:
gsm8k
.
GradeSchoolMath8K
,
# arithmetic
# arithmetic
"arithmetic_2da"
:
arithmetic
.
Arithmetic2DPlus
,
"arithmetic_2da"
:
arithmetic
.
Arithmetic2DPlus
,
"arithmetic_2ds"
:
arithmetic
.
Arithmetic2DMinus
,
"arithmetic_2ds"
:
arithmetic
.
Arithmetic2DMinus
,
...
@@ -191,22 +178,18 @@ TASK_REGISTRY = {
...
@@ -191,22 +178,18 @@ TASK_REGISTRY = {
"arithmetic_1dc"
:
arithmetic
.
Arithmetic1DComposite
,
"arithmetic_1dc"
:
arithmetic
.
Arithmetic1DComposite
,
# TODO Perhaps make these groups of tasks
# TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations
# e.g. anli, arithmetic, openai_translations, harness_translations
# hendrycksTest (57 tasks)
# hendrycksTest (57 tasks)
**
hendrycks_test
.
create_all_tasks
(),
**
hendrycks_test
.
create_all_tasks
(),
# e.g. wmt14-fr-en
# e.g. wmt14-fr-en
**
translation
.
create_tasks_from_benchmarks
(
gpt3_translation_benchmarks
),
**
translation
.
create_tasks_from_benchmarks
(
gpt3_translation_benchmarks
),
# chef's selection, mostly wmt20
# chef's selection, mostly wmt20
**
translation
.
create_tasks_from_benchmarks
(
selected_translation_benchmarks
),
**
translation
.
create_tasks_from_benchmarks
(
selected_translation_benchmarks
),
# Word Scrambling and Manipulation Tasks
# Word Scrambling and Manipulation Tasks
"anagrams1"
:
unscramble
.
Anagrams1
,
"anagrams1"
:
unscramble
.
Anagrams1
,
"anagrams2"
:
unscramble
.
Anagrams2
,
"anagrams2"
:
unscramble
.
Anagrams2
,
"cycle_letters"
:
unscramble
.
CycleLetters
,
"cycle_letters"
:
unscramble
.
CycleLetters
,
"random_insertion"
:
unscramble
.
RandomInsertion
,
"random_insertion"
:
unscramble
.
RandomInsertion
,
"reversed_words"
:
unscramble
.
ReversedWords
,
"reversed_words"
:
unscramble
.
ReversedWords
,
# Pile
# Pile
"pile_arxiv"
:
pile
.
PileArxiv
,
"pile_arxiv"
:
pile
.
PileArxiv
,
"pile_books3"
:
pile
.
PileBooks3
,
"pile_books3"
:
pile
.
PileBooks3
,
...
@@ -230,7 +213,6 @@ TASK_REGISTRY = {
...
@@ -230,7 +213,6 @@ TASK_REGISTRY = {
"pile_ubuntu-irc"
:
pile
.
PileUbuntuIrc
,
"pile_ubuntu-irc"
:
pile
.
PileUbuntuIrc
,
"pile_wikipedia"
:
pile
.
PileWikipedia
,
"pile_wikipedia"
:
pile
.
PileWikipedia
,
"pile_youtubesubtitles"
:
pile
.
PileYoutubeSubtitles
,
"pile_youtubesubtitles"
:
pile
.
PileYoutubeSubtitles
,
# BLiMP
# BLiMP
"blimp_adjunct_island"
:
blimp
.
BlimpAdjunctIsland
,
"blimp_adjunct_island"
:
blimp
.
BlimpAdjunctIsland
,
"blimp_anaphor_gender_agreement"
:
blimp
.
BlimpAnaphorGenderAgreement
,
"blimp_anaphor_gender_agreement"
:
blimp
.
BlimpAnaphorGenderAgreement
,
...
@@ -299,7 +281,6 @@ TASK_REGISTRY = {
...
@@ -299,7 +281,6 @@ TASK_REGISTRY = {
"blimp_wh_vs_that_no_gap_long_distance"
:
blimp
.
BlimpWhVsThatNoGapLongDistance
,
"blimp_wh_vs_that_no_gap_long_distance"
:
blimp
.
BlimpWhVsThatNoGapLongDistance
,
"blimp_wh_vs_that_with_gap"
:
blimp
.
BlimpWhVsThatWithGap
,
"blimp_wh_vs_that_with_gap"
:
blimp
.
BlimpWhVsThatWithGap
,
"blimp_wh_vs_that_with_gap_long_distance"
:
blimp
.
BlimpWhVsThatWithGapLongDistance
,
"blimp_wh_vs_that_with_gap_long_distance"
:
blimp
.
BlimpWhVsThatWithGapLongDistance
,
# Requires manual download of data.
# Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018,
# "storycloze_2018": storycloze.StoryCloze2018,
...
@@ -325,17 +306,23 @@ def get_task_name_from_object(task_object):
...
@@ -325,17 +306,23 @@ def get_task_name_from_object(task_object):
return
name
return
name
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return
task_object
.
EVAL_HARNESS_NAME
if
hasattr
(
task_object
,
"EVAL_HARNESS_NAME"
)
else
type
(
task_object
).
__name__
return
(
task_object
.
EVAL_HARNESS_NAME
if
hasattr
(
task_object
,
"EVAL_HARNESS_NAME"
)
else
type
(
task_object
).
__name__
)
def
get_task_dict
(
task_name_list
:
List
[
Union
[
str
,
lm_eval
.
base
.
Task
]]):
def
get_task_dict
(
task_name_list
:
List
[
Union
[
str
,
lm_eval
.
base
.
Task
]]):
task_name_dict
=
{
task_name_dict
=
{
task_name
:
get_task
(
task_name
)()
task_name
:
get_task
(
task_name
)()
for
task_name
in
task_name_list
if
isinstance
(
task_name
,
str
)
for
task_name
in
task_name_list
if
isinstance
(
task_name
,
str
)
}
}
task_name_from_object_dict
=
{
task_name_from_object_dict
=
{
get_task_name_from_object
(
task_object
):
task_object
get_task_name_from_object
(
task_object
):
task_object
for
task_object
in
task_name_list
if
not
isinstance
(
task_object
,
str
)
for
task_object
in
task_name_list
if
not
isinstance
(
task_object
,
str
)
}
}
assert
set
(
task_name_dict
.
keys
()).
isdisjoint
(
set
(
task_name_from_object_dict
.
keys
()))
assert
set
(
task_name_dict
.
keys
()).
isdisjoint
(
set
(
task_name_from_object_dict
.
keys
()))
return
{
**
task_name_dict
,
**
task_name_from_object_dict
}
return
{
**
task_name_dict
,
**
task_name_from_object_dict
}
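A quick usage note on the registry plumbing above: get_task_dict accepts registered task names and already-constructed Task objects in the same list, and the disjointness assert guards against a custom object shadowing a registered name. A minimal illustration, using task names taken from the registry above:

from lm_eval.tasks import get_task_dict

# String entries are looked up in TASK_REGISTRY and instantiated; Task
# objects would pass through keyed by EVAL_HARNESS_NAME or their class name.
task_dict = get_task_dict(["lambada", "anli_r1"])
print(sorted(task_dict.keys()))  # ['anli_r1', 'lambada']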
lm_eval/tasks/anli.py View file @ 121b7096
...
@@ -64,7 +64,12 @@ class ANLIBase(Task):
         # of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly
         # appended onto the question, with no "Answer:" or even a newline. Do we *really*
         # want to do it exactly as OA did?
-        return doc['premise'] + '\nQuestion: ' + doc['hypothesis'] + ' True, False, or Neither?\nAnswer:'
+        return (
+            doc["premise"]
+            + "\nQuestion: "
+            + doc["hypothesis"]
+            + " True, False, or Neither?\nAnswer:"
+        )

     def should_decontaminate(self):
         return True
...
@@ -76,10 +81,10 @@ class ANLIBase(Task):
         # True = entailment
         # False = contradiction
         # Neither = neutral
-        return " " + ["True", "Neither", "False"][doc['label']]
+        return " " + ["True", "Neither", "False"][doc["label"]]

     def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
+        """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.

         :param doc:
...
@@ -106,9 +111,7 @@ class ANLIBase(Task):
         """
         gold = doc["label"]
         pred = np.argmax(results)
-        return {
-            "acc": pred == gold
-        }
+        return {"acc": pred == gold}

     def aggregation(self):
         """
...
@@ -116,9 +119,7 @@ class ANLIBase(Task):
         A dictionary where keys are the names of submetrics and values are
         functions that aggregate a list of metrics
         """
-        return {
-            "acc": mean
-        }
+        return {"acc": mean}

     def higher_is_better(self):
         """
...
@@ -126,9 +127,7 @@ class ANLIBase(Task):
         A dictionary where keys are the names of submetrics and values are
         whether a higher value of the submetric is better
         """
-        return {
-            "acc": True
-        }
+        return {"acc": True}

 class ANLIRound1(ANLIBase):
...
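To make the prompt format concrete, here is what doc_to_text and doc_to_target produce for a hypothetical ANLI document (the field values are invented for illustration):

doc = {
    "premise": "The festival was cancelled due to rain.",
    "hypothesis": "The festival did not take place.",
    "label": 0,  # 0 = entailment
}
prompt = (
    doc["premise"]
    + "\nQuestion: "
    + doc["hypothesis"]
    + " True, False, or Neither?\nAnswer:"
)
target = " " + ["True", "Neither", "False"][doc["label"]]  # " True"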
lm_eval/tasks/arithmetic.py View file @ 121b7096
...
@@ -67,10 +67,8 @@ class Arithmetic(Task):
         return is_prediction

     def process_results(self, doc, results):
-        is_prediction, = results
-        return {
-            "acc": is_prediction
-        }
+        (is_prediction,) = results
+        return {"acc": is_prediction}

     def aggregation(self):
         return {
...
@@ -78,9 +76,7 @@ class Arithmetic(Task):
         }

     def higher_is_better(self):
-        return {
-            "acc": True
-        }
+        return {"acc": True}

 class Arithmetic2DPlus(Arithmetic):
...
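One of the pre-commit rewrites above is purely stylistic but worth spelling out: `(is_prediction,) = results` is the same single-element unpack as the old trailing-comma form, and both raise if the sequence does not hold exactly one item. A two-line demonstration:

(is_prediction,) = [True]   # equivalent to the old: is_prediction, = [True]
# (is_prediction,) = [True, False] would raise ValueError: too many values to unpack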
lm_eval/tasks/asdiv.py View file @ 121b7096
...
@@ -54,29 +54,28 @@ class Asdiv(Task):
     def test_docs(self):
         raise NotImplementedError("This dataset has no test docs")

-    def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
+    def fewshot_context(
+        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
+    ):
         assert num_fewshot == 0, "ASDiv is intended only for the zero-shot setting."
         return super().fewshot_context(
             doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
         )

     def doc_to_text(self, doc):
         # TODO: add solution-type
-        return doc['body'] + '\n' + 'Question:' + doc['question'] + '\n' + 'Answer:'
+        return doc["body"] + "\n" + "Question:" + doc["question"] + "\n" + "Answer:"

     def should_decontaminate(self):
         return True

     def doc_to_decontamination_query(self, doc):
-        return doc['body'] + " " + doc['question']
+        return doc["body"] + " " + doc["question"]

     def doc_to_target(self, doc):
         # TODO: add formula
-        answer = doc['answer'].split(' (')[0]
+        answer = doc["answer"].split(" (")[0]
         return " " + answer

     def construct_requests(self, doc, ctx):
...
@@ -86,16 +85,10 @@ class Asdiv(Task):
     def process_results(self, doc, results):
         ll, is_greedy = results
-        return {
-            'acc': int(is_greedy)
-        }
+        return {"acc": int(is_greedy)}

     def aggregation(self):
-        return {
-            'acc': mean
-        }
+        return {"acc": mean}

     def higher_is_better(self):
-        return {
-            'acc': True
-        }
+        return {"acc": True}
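Concretely, ASDiv answers carry their unit in parentheses, which doc_to_target strips before the comparison. For a hypothetical document (values invented for illustration):

doc = {
    "body": "Marin has 9 balls. Ellen has 6 more balls than Marin.",
    "question": "How many balls does Ellen have?",
    "answer": "15 (balls)",
}
prompt = doc["body"] + "\n" + "Question:" + doc["question"] + "\n" + "Answer:"
target = " " + doc["answer"].split(" (")[0]  # " 15" -- the "(balls)" unit is dropped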
lm_eval/tasks/blimp.py View file @ 121b7096
...
@@ -50,9 +50,13 @@ class BlimpTask(Task):
         # trained on this data.
         return self.dataset["train"]

-    def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
+    def fewshot_context(
+        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
+    ):
         assert num_fewshot == 0
-        assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`"
+        assert (
+            rnd is not None
+        ), "A `random.Random` generator argument must be provided to `rnd`"
         assert not provide_description, (
             "The `provide_description` arg will be removed in future versions. To prepend "
             "a custom description to the context, supply the corresponding string via the "
...
@@ -60,7 +64,9 @@ class BlimpTask(Task):
         )

         if provide_description is not None:
             # nudge people to not specify it at all
-            print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
+            print(
+                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
+            )

         return ""
...
lm_eval/tasks/cbt.py View file @ 121b7096
...
@@ -86,7 +86,9 @@ class CBTBase(Task):
         return ""

     def fewshot_examples(self, k, rnd):
-        assert k == 0, f"CBT is only implemented for the zero-shot setting. Given k={k}."
+        assert (
+            k == 0
+        ), f"CBT is only implemented for the zero-shot setting. Given k={k}."
         return super().fewshot_examples(k, rnd)

     def construct_requests(self, doc, ctx):
...
@@ -120,9 +122,7 @@ class CBTBase(Task):
         """
         gold = doc["options"].index(doc["answer"])
         pred = np.argmax(results)
-        return {
-            "acc": pred == gold
-        }
+        return {"acc": pred == gold}

     def aggregation(self):
         """
...
@@ -130,9 +130,7 @@ class CBTBase(Task):
         A dictionary where keys are the names of submetrics and values are
         functions that aggregate a list of metrics
         """
-        return {
-            "acc": mean
-        }
+        return {"acc": mean}

     def higher_is_better(self):
         """
...
@@ -140,9 +138,7 @@ class CBTBase(Task):
         A dictionary where keys are the names of submetrics and values are
         whether a higher value of the submetric is better
         """
-        return {
-            "acc": True
-        }
+        return {"acc": True}

 class CBTCN(CBTBase):
...
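The process_results pattern here (and in several tasks above) reduces multiple-choice scoring to an argmax over per-option loglikelihoods. A toy illustration with an invented document and invented scores:

import numpy as np

doc = {"options": ["ball", "book", "cat"], "answer": "book"}  # hypothetical doc
results = [-4.2, -1.3, -3.7]  # one loglikelihood per option, from the LM
gold = doc["options"].index(doc["answer"])  # 1
pred = np.argmax(results)                   # 1: "book" is the most likely option
acc = pred == gold                          # True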
lm_eval/tasks/coqa.py View file @ 121b7096
...
@@ -54,8 +54,10 @@ class CoQA(Task):
     def doc_to_text(self, doc):
         # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
         # and a question qi, the task is to predict the answer ai
-        doc_text = doc["story"] + '\n\n'
-        for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]):  # omit target answer ai
+        doc_text = doc["story"] + "\n\n"
+        for (q, a) in zip_longest(
+            doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
+        ):  # omit target answer ai
             question = f"Q: {q}\n\n"
             answer = f"A: {a}\n\n" if a is not None else "A:"
             doc_text += question + answer
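Because the final answer is omitted, zip_longest pairs the last question with None and the prompt ends on a bare "A:". A worked example on a hypothetical two-turn document:

from itertools import zip_longest

doc = {  # hypothetical CoQA-style document
    "story": "Anna moved to Berlin in 2010.",
    "questions": {"input_text": ["Who moved?", "Where to?"]},
    "answers": {"input_text": ["Anna", "Berlin"]},
}
text = doc["story"] + "\n\n"
for q, a in zip_longest(
    doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
):
    text += f"Q: {q}\n\n" + (f"A: {a}\n\n" if a is not None else "A:")
# text now ends with "Q: Where to?\n\nA:" -- the model must supply " Berlin".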
...
@@ -77,7 +79,9 @@ class CoQA(Task):
         additional_answers = doc.get("additional_answers")
         if additional_answers:
             for key in additional_answers:
-                additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1]
+                additional_answer_for_turn = additional_answers[key]["input_text"][
+                    turn_id - 1
+                ]
                 if additional_answer_for_turn.lower() not in map(str.lower, answers):
                     answers.append(additional_answer_for_turn)
         return answers
...
@@ -89,12 +93,12 @@ class CoQA(Task):
         # ~ 2/3 of the CoQA answers are span-based
         # (answers overlap with the passage ignoring punctuation and case mismatch)
         if raw_text == "unknown":
-            return '0'
+            return "0"
         if squad_metrics.normalize_answer(raw_text) == "yes":
-            return '1'
+            return "1"
         if squad_metrics.normalize_answer(raw_text) == "no":
-            return '2'
+            return "2"
-        return '3'  # Not a yes/no question
+        return "3"  # Not a yes/no question

     @staticmethod
     def compute_scores(gold_list, pred):
...
@@ -104,25 +108,30 @@ class CoQA(Task):
         em_sum = 0.0
         if len(gold_list) > 1:
             for i in range(len(gold_list)):
-                gold_answers = gold_list[0:i] + gold_list[i + 1:]
+                gold_answers = gold_list[0:i] + gold_list[i + 1 :]
                 # predictions compared against (n) golds and take maximum
                 em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
                 f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
         else:
             em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
             f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)

-        return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))}
+        return {
+            "em": em_sum / max(1, len(gold_list)),
+            "f1": f1_sum / max(1, len(gold_list)),
+        }

     def doc_to_target(self, doc, turnid=None):
         # Default to prediction of last turn.
         if turnid is None:
             turnid = len(doc["questions"]["input_text"])
-        raw_text = doc['answers']["input_text"][turnid - 1]
+        raw_text = doc["answers"]["input_text"][turnid - 1]
         return " " + raw_text

     def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
+        """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.

         :param doc:
...
@@ -132,7 +141,7 @@ class CoQA(Task):
         language description, as well as the few shot examples, and the question
         part of the document for `doc`.
         """
-        cont_request = rf.greedy_until(ctx, ['\nQ:'])
+        cont_request = rf.greedy_until(ctx, ["\nQ:"])
         return cont_request

     def process_results(self, doc, results):
...
@@ -147,13 +156,13 @@ class CoQA(Task):
         """
         turn_id = len(doc["questions"]["input_text"])
         gold_list = self.get_answers(doc, turn_id)
-        pred = results[0].strip().split('\n')[0]
+        pred = results[0].strip().split("\n")[0]

         scores = self.compute_scores(gold_list, pred)

         return {
-            "f1": scores['f1'],
-            "em": scores['em'],
+            "f1": scores["f1"],
+            "em": scores["em"],
         }

     def higher_is_better(self):
...
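The multi-reference scoring in compute_scores is the CoQA leave-one-out scheme: when several gold answers exist, each one in turn is held out and the prediction is scored against the rest, which keeps a model from being rewarded for matching a single annotator's idiosyncratic phrasing. A minimal sketch with exact string match standing in for the squad_metrics functions:

def leave_one_out_em(gold_list, pred):
    # Score pred against every gold set with one reference held out,
    # then average; exact string equality stands in for compute_exact.
    if len(gold_list) == 1:
        return float(pred == gold_list[0])
    total = 0.0
    for i in range(len(gold_list)):
        others = gold_list[:i] + gold_list[i + 1:]
        total += max(float(pred == g) for g in others)
    return total / len(gold_list)

print(leave_one_out_em(["Paris", "paris", "Lyon"], "Paris"))  # 2/3 under exact match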
lm_eval/tasks/drop.py View file @ 121b7096
...
@@ -70,21 +70,26 @@ class DROP(Task):
     @classmethod
     def get_answers(cls, qa):
         def _flatten_validated_answers(validated_answers):
-            """ Flattens a dict of lists of validated answers.
+            """Flattens a dict of lists of validated answers.
             {"number": ['1', '8'], ...}
             -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
             """
             vas = []
             for i in range(len(validated_answers["number"])):
-                vas.append({
-                    "number": validated_answers["number"][i],
-                    "date": validated_answers["date"][i],
-                    "spans": validated_answers["spans"][i],
-                })
+                vas.append(
+                    {
+                        "number": validated_answers["number"][i],
+                        "date": validated_answers["date"][i],
+                        "spans": validated_answers["spans"][i],
+                    }
+                )
             return vas

         answers = []
         answers_set = set()
-        candidates = [qa["answer"]] + _flatten_validated_answers(qa["validated_answers"])
+        candidates = [qa["answer"]] + _flatten_validated_answers(
+            qa["validated_answers"]
+        )
         for candidate in candidates:
             answer = cls.parse_answer(candidate)
             if answer in answers_set:
...
@@ -100,9 +105,11 @@ class DROP(Task):
             return (str(answer["number"]),)
         if answer["spans"] != []:
             return tuple(answer["spans"])
-        return (" ".join([answer["date"]["day"],
-                          answer["date"]["month"],
-                          answer["date"]["year"]]).strip(),)
+        return (
+            " ".join(
+                [answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
+            ).strip(),
+        )

     def doc_to_text(self, doc):
         return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
...
@@ -111,7 +118,7 @@ class DROP(Task):
         return True

     def doc_to_decontamination_query(self, doc):
-        return doc['passage'] + " " + doc['question']
+        return doc["passage"] + " " + doc["question"]

     def doc_to_target(self, doc):
         return " " + ", ".join(doc["answers"][0])
...
@@ -148,10 +155,7 @@ class DROP(Task):
         if gold_answer[0].strip():
             max_em = max(max_em, exact_match)
             max_f1 = max(max_f1, f1_score)
-        return {
-            "em": max_em,
-            "f1": max_f1
-        }
+        return {"em": max_em, "f1": max_f1}

     def get_metrics(self, predicted, gold):
         """
...
@@ -164,7 +168,9 @@ class DROP(Task):
         predicted_bags = self._answer_to_bags(predicted)
         gold_bags = self._answer_to_bags(gold)

-        if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]):
+        if set(predicted_bags[0]) == set(gold_bags[0]) and len(
+            predicted_bags[0]
+        ) == len(gold_bags[0]):
             exact_match = 1.0
         else:
             exact_match = 0.0
...
@@ -196,7 +202,9 @@ class DROP(Task):
         for gold_index, gold_item in enumerate(gold):
             for pred_index, pred_item in enumerate(predicted):
                 if self._match_numbers_if_present(gold_item, pred_item):
-                    scores[gold_index, pred_index] = self._compute_f1(pred_item, gold_item)
+                    scores[gold_index, pred_index] = self._compute_f1(
+                        pred_item, gold_item
+                    )
         row_ind, col_ind = linear_sum_assignment(-scores)

         max_scores = np.zeros([max(len(gold), len(predicted))])
...
@@ -262,7 +270,11 @@ class DROP(Task):
     def _normalize(self, answer):
         tokens = [
-            self._white_space_fix(self._remove_articles(self._fix_number(self._remove_punc(token.lower()))))
+            self._white_space_fix(
+                self._remove_articles(self._fix_number(self._remove_punc(token.lower())))
+            )
             for token in self._tokenize(answer)
         ]
         tokens = [token for token in tokens if token.strip()]
...
@@ -275,10 +287,7 @@ class DROP(Task):
         A dictionary where keys are the names of submetrics and values are
         functions that aggregate a list of metrics
         """
-        return {
-            "em": mean,
-            "f1": mean
-        }
+        return {"em": mean, "f1": mean}

     def higher_is_better(self):
         """
...
@@ -286,7 +295,4 @@ class DROP(Task):
         A dictionary where keys are the names of submetrics and values are
         whether a higher value of the submetric is better
         """
-        return {
-            "em": True,
-            "f1": True
-        }
+        return {"em": True, "f1": True}
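The alignment step above relies on scipy's Hungarian solver: every gold span is matched to at most one predicted span so that the summed F1 is maximal, rather than letting one good prediction be counted against several golds. A compressed sketch of that step, with a trivial equality F1 as a toy stand-in for the bag-of-words F1 (function names are local to the sketch):

import numpy as np
from scipy.optimize import linear_sum_assignment

def align_scores(gold, predicted, f1):
    # scores[i, j] = F1 of predicted[j] against gold[i]
    scores = np.zeros([len(gold), len(predicted)])
    for gi, g in enumerate(gold):
        for pi, p in enumerate(predicted):
            scores[gi, pi] = f1(p, g)
    # Maximize total F1 over one-to-one assignments (hence the negation).
    row_ind, col_ind = linear_sum_assignment(-scores)
    max_scores = np.zeros([max(len(gold), len(predicted))])
    for r, c in zip(row_ind, col_ind):
        max_scores[r] = max(max_scores[r], scores[r, c])
    return max_scores

token_f1 = lambda p, g: float(p == g)  # toy stand-in for the real bag F1
print(align_scores(["a", "b"], ["b", "a"], token_f1))  # [1. 1.]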