gaoqiong / lm-evaluation-harness, commit 121b7096

add pre-commit

Authored May 02, 2022 by Fabrizio Milo
Parent commit: 7a038118
Changes: 120 files in this commit. This page (one of six) shows 20 changed files, with 439 additions and 240 deletions (+439 -240):

.coveragerc                                         +2    -2
.flake8                                             +5    -0
.github/workflows/pull_request.yml                 +13    -0
.gitignore                                          +1    -1
.pre-commit-config.yaml                            +42    -0
README.md                                           +4    -4
docs/decontamination.md                             +2    -3
docs/task_guide.md                                  +9    -9
lm_eval/base.py                                   +160   -98
lm_eval/datasets/README.md                          +2    -2
lm_eval/datasets/arithmetic/arithmetic.py          +69   -16
lm_eval/datasets/arithmetic/dataset_infos.json      +1    -1
lm_eval/datasets/asdiv/asdiv.py                     +9    -4
lm_eval/datasets/asdiv/dataset_infos.json           +1    -1
lm_eval/datasets/coqa/coqa.py                      +51   -42
lm_eval/datasets/coqa/dataset_infos.json            +1    -1
lm_eval/datasets/drop/dataset_infos.json            +1    -1
lm_eval/datasets/drop/drop.py                      +56   -45
lm_eval/datasets/gsm8k/dataset_infos.json           +1    -1
lm_eval/datasets/gsm8k/gsm8k.py                     +9    -9
.coveragerc  (+2 -2; diff not expanded)
.flake8  (new file, +5 -0)
[flake8]
ignore = E203, E266, E501, W503, F403, F401, C901
max-line-length = 127
max-complexity = 10
select = B,C,E,F,W,T4,B9
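As an aside (not part of the commit), the practical effect of this configuration: E501 line-length errors fire only past 127 characters, and several pycodestyle/pyflakes codes, including F401 for unused imports, are suppressed. A hypothetical module like the following would pass cleanly:

# hypothetical_module.py -- illustrative only; passes the .flake8 config above
import os  # unused: would normally raise F401, but F401 is in the ignore list


def summarize(values):
    # lines may run to 127 characters (max-line-length = 127) before E501 fires
    return {"count": len(values), "first": values[0] if values else None, "last": values[-1] if values else None}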
.github/workflows/pull_request.yml  (new file, +13 -0)

name: Pull Request

on: [pull_request]

jobs:
  pre-commit:
    runs-on: ubuntu-20.04
    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-python@v2
        with:
          python-version: 3.8
      - uses: pre-commit/action@v2.0.3
.gitignore  (+1 -1; diff not expanded)
.pre-commit-config.yaml  (new file, +42 -0)

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.1.0
    hooks:
      - id: check-added-large-files
      - id: check-ast
      - id: check-byte-order-marker
      - id: check-case-conflict
      - id: check-json
      - id: check-merge-conflict
      - id: check-symlinks
      - id: check-yaml
      - id: destroyed-symlinks
      - id: detect-private-key
      - id: end-of-file-fixer
      - id: no-commit-to-branch
      - id: requirements-txt-fixer
      - id: trailing-whitespace
      - id: fix-byte-order-marker
        exclude: docs/CNAME
      - id: fix-encoding-pragma
        args: [--remove]
      - id: mixed-line-ending
        args: [--fix=lf]
  - repo: https://gitlab.com/pycqa/flake8
    rev: 3.7.9
    hooks:
      - id: flake8
  - repo: https://github.com/psf/black
    rev: 22.3.0
    hooks:
      - id: black
        language_version: python3.8
  - repo: https://github.com/codespell-project/codespell
    rev: v2.1.0
    hooks:
      - id: codespell
        args: [
            "--ignore-words-list=reord",  # Word used in error messages that need rewording
            --check-filenames,
            --check-hidden,
          ]
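A brief usage sketch (the assumed developer workflow, not part of the diff): once pre-commit is installed, the hooks above run on every git commit, and the CI workflow runs them against the whole tree. The subprocess calls below simply drive the standard CLI commands (pre-commit install, pre-commit run --all-files):

# Illustrative sketch: driving pre-commit from Python via its CLI.
# Assumes `pip install pre-commit` and that this runs at the repo root.
import subprocess

# Register the git hook so every `git commit` runs the configured checks.
subprocess.run(["pre-commit", "install"], check=True)

# Run every hook against the whole tree once, as the CI workflow above does.
result = subprocess.run(["pre-commit", "run", "--all-files"])
print("all hooks passed" if result.returncode == 0 else "some hooks failed or reformatted files")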
README.md  (+4 -4; diff not expanded)
docs/decontamination.md  (+2 -3)

@@ -73,4 +73,3 @@ python -m scripts/clean_training_data/compress_and_package \
 ```
 Congratulations, the final directory can now be passed to lm-evaulation-harness with the "--decontamination_ngrams_path" argument.
docs/task_guide.md  (+9 -9; diff not expanded)
lm_eval/base.py  (+160 -98)

Most of this file's diff is mechanical reformatting from the new black/flake8 hooks: single quotes become double quotes, long calls are wrapped with trailing commas, and single-element tuple unpacking gains explicit parentheses. The hunks below show the post-commit code.

@@ -118,7 +118,6 @@ class LM(abc.ABC):
class BaseLM(LM):
    @property
    @abstractmethod
    def eot_token_id(self):
@@ -145,13 +144,16 @@ class BaseLM(LM):
        pass

    @abstractmethod
    def tok_encode(self, string: str):
        pass

    @abstractmethod
    def tok_decode(self, tokens: Iterable[int]):
        pass

    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id):
        pass

    @abstractmethod
    def _model_call(self, inps):
@@ -187,19 +189,26 @@ class BaseLM(LM):
        # TODO: automatic batch size detection for vectorization
        loglikelihoods = []
        for (string,) in tqdm(requests):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that
            string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True)

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]
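To make the windowing step concrete, here is a small standalone sketch of the rolling-window idea (a simplified, illustrative stand-in for utils.get_rolling_token_windows and utils.make_disjoint_window, not the library code itself): a long token list is tiled by disjoint continuations of at most max_seq_len tokens, each conditioned on what came before, so every token's log-likelihood is scored exactly once.

# Illustrative sketch of rolling-window likelihood scoring; a simplified
# stand-in for the utils helpers called above, not the real implementation.
def rolling_windows(token_list, prefix_token, max_seq_len):
    """Yield (context, continuation) pairs whose continuations tile token_list.

    Every token appears in exactly one continuation, so summing the
    continuation log-likelihoods gives the full-string log-likelihood.
    """
    start = 0
    while start < len(token_list):
        chunk = token_list[start : start + max_seq_len]
        # The first window is conditioned only on the prefix (EOT) token;
        # later windows are conditioned on the tokens scored so far.
        context = [prefix_token] if start == 0 else token_list[:start]
        yield context, chunk
        start += len(chunk)

if __name__ == "__main__":
    tokens = list(range(10))  # toy "tokenized string"
    for ctx, cont in rolling_windows(tokens, prefix_token=-1, max_seq_len=4):
        print(f"context={ctx} -> continuation={cont}")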
@@ -226,7 +235,9 @@ class BaseLM(LM):
        # TODO: automatic (variable) batch size detection for vectorization
        reord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
            tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size
        ):
            inps = []
            cont_toks_list = []
            inplens = []
@@ -252,44 +263,60 @@ class BaseLM(LM):
            # when too long to fit in context, truncate from the left
            inp = torch.tensor(
                (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                dtype=torch.long,
            ).to(self.device)
            (inplen,) = inp.shape

            cont = continuation_enc

            # since in _collate we make sure length is descending, the longest is always the first one.
            padding_length = (
                padding_length if padding_length is not None else inplen
            )

            # pad length from seq to padding_length
            inp = torch.cat(
                [
                    inp,  # [seq]
                    torch.zeros(padding_length - inplen, dtype=torch.long).to(
                        inp.device
                    ),  # [padding_length - seq]
                ],
                dim=0,
            )

            inps.append(inp.unsqueeze(0))  # [1, padding_length]
            cont_toks_list.append(cont)
            inplens.append(inplen)

        batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length
        multi_logits = F.log_softmax(
            self._model_call(batched_inps), dim=-1
        ).cpu()  # [batch, padding_length, vocab]

        for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
            chunk, multi_logits, inps, inplens, cont_toks_list
        ):
            # Slice to original seq length
            contlen = len(cont_toks)
            logits = logits[inplen - contlen : inplen].unsqueeze(0)  # [1, seq, vocab]

            # Check if per-token argmax is exactly equal to continuation
            greedy_tokens = logits.argmax(dim=-1)
            cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)  # [1, seq]
            max_equal = (greedy_tokens == cont_toks).all()

            # Obtain log-probs at the corresponding continuation token indices
            # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
            logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]

            # Answer: (log prob, is-exact-match)
            answer = (float(logits.sum()), bool(max_equal))
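The truncate-then-pad step above is easy to check in isolation. A minimal sketch (illustrative token values; padding_length is an assumed batch maximum):

# Minimal sketch of the left-truncation + right-padding done above.
import torch

max_length = 8
context_enc = [11, 12, 13, 14, 15, 16]
continuation_enc = [17, 18, 19, 20]

# Keep at most max_length + 1 trailing tokens, then drop the final token:
# the model predicts each position from the one before it.
inp = torch.tensor(
    (context_enc + continuation_enc)[-(max_length + 1) :][:-1], dtype=torch.long
)
(inplen,) = inp.shape

# Pad on the right up to the batch's longest sequence (here: an assumed value).
padding_length = 10
inp = torch.cat(
    [inp, torch.zeros(padding_length - inplen, dtype=torch.long)], dim=0
)
print(inp, inp.shape)  # tensor of length padding_length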
@@ -319,13 +346,17 @@ class BaseLM(LM):
            if isinstance(until, str):
                until = [until]

            (primary_until,) = self.tok_encode(until[0])

            context_enc = torch.tensor(
                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
            ).to(self.device)

            cont = self._model_generate(
                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
            )

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])

            for term in until:
                s = s.split(term)[0]
@@ -383,7 +414,7 @@ class Task(abc.ABC):
        self._fewshot_docs = None

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Downloads and returns the task dataset.

        Override this method to download the dataset from a custom API.

        :param data_dir: str
@@ -412,7 +443,7 @@ class Task(abc.ABC):
            name=self.DATASET_NAME,
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
        )

    def should_decontaminate(self):
@@ -473,8 +504,10 @@ class Task(abc.ABC):
            return rnd.sample(self._training_docs, k)

    def doc_to_decontamination_query(self, doc):
        print(
            "Override doc_to_decontamination_query with document specific decontamination query."
        )
        assert False

    @abstractmethod
    def doc_to_text(self, doc):
@@ -486,7 +519,7 @@ class Task(abc.ABC):
    @abstractmethod
    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
@@ -531,15 +564,19 @@ class Task(abc.ABC):
    def fewshot_description(self):
        import warnings

        warnings.warn(
            "`fewshot_description` will be removed in futures versions. Pass "
            "any custom descriptions to the `evaluate` function instead.",
            DeprecationWarning,
        )
        return ""

    @utils.positional_deprecated
    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
@@ -556,7 +593,9 @@ class Task(abc.ABC):
        :returns: str
            The fewshot context.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
@@ -564,7 +603,9 @@ class Task(abc.ABC):
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        description = description + "\n\n" if description else ""
@@ -577,7 +618,9 @@ class Task(abc.ABC):
        else:
            if self._fewshot_docs is None:
                self._fewshot_docs = list(
                    self.validation_docs()
                    if self.has_validation_docs()
                    else self.test_docs()
                )

            fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
@@ -585,23 +628,27 @@ class Task(abc.ABC):
            # get rid of the doc that's the one we're evaluating, if it's in the fewshot
            fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = (
                "\n\n".join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + "\n\n"
            )

        example = self.doc_to_text(doc)
        return description + labeled_examples + example


class MultipleChoiceTask(Task):
    def doc_to_target(self, doc):
        return " " + doc["choices"][doc["gold"]]

    def construct_requests(self, doc, ctx):
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0]
            for choice in doc["choices"]
        ]

        return lls
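The few-shot prompt assembled above is plain string concatenation. A toy sketch of the final layout (hypothetical documents, not from a real task):

# Toy sketch of how fewshot_context composes the prompt string.
description = "Answer the arithmetic question."
fewshotex = [
    {"text": "Q: 1+1\nA:", "target": " 2"},
    {"text": "Q: 2+3\nA:", "target": " 5"},
]
labeled_examples = (
    "\n\n".join(ex["text"] + ex["target"] for ex in fewshotex) + "\n\n"
)
example = "Q: 4+4\nA:"  # the document being evaluated, without its answer
prompt = (description + "\n\n" if description else "") + labeled_examples + example
print(prompt)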
@@ -609,9 +656,9 @@ class MultipleChoiceTask(Task):
    def process_results(self, doc, results):
        gold = doc["gold"]

        acc = 1.0 if np.argmax(results) == gold else 0.0
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0

        return {
            "acc": acc,
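For intuition, the normalization above divides each choice's total log-likelihood by its character length before taking the argmax, which removes the bias toward short answers. A worked toy example (made-up numbers):

# Toy illustration of acc vs. length-normalized acc_norm (made-up numbers).
import numpy as np

results = np.array([-9.0, -12.0])          # total log-likelihood per choice
choices = ["no", "absolutely"]             # candidate completions
gold = 1                                   # index of the correct choice

acc = 1.0 if np.argmax(results) == gold else 0.0           # picks "no" -> 0.0
completion_len = np.array([float(len(c)) for c in choices])
acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
print(acc, acc_norm)  # 0.0 1.0: per-character likelihood favors the long answer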
@@ -632,7 +679,6 @@ class MultipleChoiceTask(Task):
class PerplexityTask(Task, abc.ABC):
    def should_decontaminate(self):
        """Whether this task supports decontamination against model training set."""
        return True
@@ -644,9 +690,15 @@ class PerplexityTask(Task, abc.ABC):
        assert k == 0
        return []

    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        assert (
            num_fewshot == 0
        ), "The number of fewshot examples must be 0 for perplexity tasks."
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`."
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
@@ -654,7 +706,9 @@ class PerplexityTask(Task, abc.ABC):
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        return ""
@@ -680,7 +734,7 @@ class PerplexityTask(Task, abc.ABC):
        return req

    def process_results(self, doc, results):
        (loglikelihood,) = results
        words = self.count_words(doc)
        bytes_ = self.count_bytes(doc)
        return {
@@ -702,13 +756,13 @@ class PerplexityTask(Task, abc.ABC):
    @classmethod
    def count_words(cls, doc):
        """Downstream tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))


def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()


class CacheHook:
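The hash_args helper gives a deterministic cache key: serialize the request attribute and arguments to JSON, then SHA-256 the bytes. A quick check (illustrative values):

# Quick check that hash_args-style keys are stable across runs.
import hashlib
import json

def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()

key = hash_args("loglikelihood", ("Q: 1+1\nA:", " 2"))
print(key)  # same 64-hex-character digest every time for the same inputs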
@@ -779,6 +833,7 @@ class CachingLM:
                self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
@@ -786,16 +841,18 @@ class CachingLM:
REQUEST_RETURN_LENGTHS
=
{
REQUEST_RETURN_LENGTHS
=
{
'
loglikelihood
'
:
2
,
"
loglikelihood
"
:
2
,
'
greedy_until
'
:
None
,
"
greedy_until
"
:
None
,
'
loglikelihood_rolling
'
:
None
,
"
loglikelihood_rolling
"
:
None
,
}
}
class
Request
:
class
Request
:
def
__init__
(
self
,
request_type
,
args
,
index
=
None
):
def
__init__
(
self
,
request_type
,
args
,
index
=
None
):
if
request_type
not
in
REQUEST_RETURN_LENGTHS
.
keys
():
if
request_type
not
in
REQUEST_RETURN_LENGTHS
.
keys
():
raise
NotImplementedError
(
'The request type {} is not implemented!'
.
format
(
request_type
))
raise
NotImplementedError
(
"The request type {} is not implemented!"
.
format
(
request_type
)
)
self
.
request_type
=
request_type
self
.
request_type
=
request_type
self
.
args
=
args
self
.
args
=
args
@@ -803,17 +860,21 @@ class Request:
    def __iter__(self):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
            yield Request(self.request_type, self.args, i)

    def __getitem__(self, i):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        return Request(self.request_type, self.args, i)

    def __eq__(self, other):
        return (
            self.request_type == other.request_type
            and self.args == other.args
            and self.index == other.index
        )

    def __repr__(self):
        return f"Req_{self.request_type}{self.args}[{self.index}]\n"
@@ -823,6 +884,7 @@ class RequestFactory:
    def __getattr__(self, attr):
        def fn(*args):
            return Request(attr, args)

        return fn
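The __iter__ protocol above is what lets task code write rf.loglikelihood(ctx, cont)[0], or unpack a request into its indexed parts. A trimmed, self-contained sketch (mirrors the diff, reduced to the pieces needed here):

# Self-contained sketch of how an indexed Request splits apart.
REQUEST_RETURN_LENGTHS = {"loglikelihood": 2, "greedy_until": None}

class Request:
    def __init__(self, request_type, args, index=None):
        self.request_type = request_type
        self.args = args
        self.index = index

    def __iter__(self):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
            yield Request(self.request_type, self.args, i)

    def __repr__(self):
        return f"Req_{self.request_type}{self.args}[{self.index}]"

# A loglikelihood request returns (logprob, is_greedy), so it unpacks into two
# indexed sub-requests that can be cached and resolved independently:
ll, is_greedy = Request("loglikelihood", ("context", " continuation"))
print(ll, is_greedy)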
lm_eval/datasets/README.md  (+2 -2; diff not expanded)
lm_eval/datasets/arithmetic/arithmetic.py  (+69 -16)

@@ -68,61 +68,111 @@ class Arithmetic(datasets.GeneratorBasedBuilder):
        ArithmeticConfig(
            name="arithmetic_2da",
            url="https://raw.githubusercontent.com/openai/gpt-3/master/data/two_digit_addition.jsonl",
            features=datasets.Features(
                {
                    "context": datasets.Value("string"),
                    "completion": datasets.Value("string"),
                }
            ),
            description="2-digit addition",
        ),
        ...
    ]

The same reformatting (the features dict moved onto its own lines, with trailing commas) is applied uniformly to all ten ArithmeticConfig entries:

    arithmetic_2da  two_digit_addition.jsonl        "2-digit addition"
    arithmetic_2ds  two_digit_subtraction.jsonl     "2-digit subtraction"
    arithmetic_3da  three_digit_addition.jsonl      "3-digit addition"
    arithmetic_3ds  three_digit_subtraction.jsonl   "3-digit subtraction"
    arithmetic_4da  four_digit_addition.jsonl       "4-digit addition"
    arithmetic_4ds  four_digit_subtraction.jsonl    "4-digit subtraction"
    arithmetic_5da  five_digit_addition.jsonl       "5-digit addition"
    arithmetic_5ds  five_digit_subtraction.jsonl    "5-digit subtraction"
    arithmetic_2dm  two_digit_multiplication.jsonl  "2-digit multiplication"
    arithmetic_1dc  single_digit_three_ops.jsonl    "Single digit 3 operations"

(all URLs under https://raw.githubusercontent.com/openai/gpt-3/master/data/)
@@ -155,9 +205,12 @@ class Arithmetic(datasets.GeneratorBasedBuilder):
        with open(filepath, encoding="utf-8") as f:
            for key, row in enumerate(f):
                data = json.loads(row)
                context = (
                    data["context"]
                    .strip()
                    .replace("\n\n", "\n")
                    .replace("Q:", "Question:")
                    .replace("A:", "Answer:")
                )
                completion = data["completion"]
                yield key, {"context": context, "completion": completion}
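The old version expressed the same cleanup with backslash continuations; black prefers a parenthesized method chain. A toy run of the transformation (made-up JSONL row):

# Toy run of the context cleanup applied to each JSONL row (made-up data).
import json

row = '{"context": "Q: What is 95 plus 70?\\n\\nA:", "completion": " 165"}'
data = json.loads(row)
context = (
    data["context"]
    .strip()
    .replace("\n\n", "\n")
    .replace("Q:", "Question:")
    .replace("A:", "Answer:")
)
print(repr(context))  # 'Question: What is 95 plus 70?\nAnswer:'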
lm_eval/datasets/arithmetic/dataset_infos.json  (+1 -1; diff not expanded)
lm_eval/datasets/asdiv/asdiv.py  (+9 -4)

@@ -50,13 +50,16 @@ _URLS = "https://github.com/chaochun/nlu-asdiv-dataset/archive/55790e5270bb91ccf
class ASDiv(datasets.GeneratorBasedBuilder):
    """ASDiv: A Diverse Corpus for Evaluating and Developing English Math Word Problem Solvers"""

    VERSION = datasets.Version("0.0.1")

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(
            name="asdiv",
            version=VERSION,
            description="A diverse corpus for evaluating and developing english math word problem solvers",
        )
    ]

    def _info(self):

@@ -86,7 +89,9 @@ class ASDiv(datasets.GeneratorBasedBuilder):
                name=datasets.Split.VALIDATION,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "filepath": os.path.join(
                        data_dir, base_filepath, "dataset", "ASDiv.xml"
                    ),
                    "split": datasets.Split.VALIDATION,
                },
            ),
lm_eval/datasets/asdiv/dataset_infos.json  (+1 -1; diff not expanded)
lm_eval/datasets/coqa/coqa.py  (+51 -42)

@@ -61,7 +61,7 @@ _EMPTY_ADDITIONAL_ANSWER = {
                "span_end": -1,
                "span_text": "",
                "input_text": "",
                "turn_id": -1,
            }
        ],
        "1": [

(hunks @@ -70,7 and @@ -79,7 add the same trailing comma after "turn_id": -1 in the "1" and "2" entries of _EMPTY_ADDITIONAL_ANSWER)

@@ -91,8 +91,9 @@ class Coqa(datasets.GeneratorBasedBuilder):
    VERSION = datasets.Version("0.0.1")

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(
            name="coqa", version=VERSION, description="The CoQA dataset."
        ),
    ]

    def _info(self):

@@ -101,41 +102,52 @@ class Coqa(datasets.GeneratorBasedBuilder):
                "id": datasets.Value("string"),
                "source": datasets.Value("string"),
                "story": datasets.Value("string"),
                "questions": datasets.features.Sequence(
                    {
                        "input_text": datasets.Value("string"),
                        "turn_id": datasets.Value("int32"),
                    }
                ),
                "answers": datasets.features.Sequence(
                    {
                        "span_start": datasets.Value("int32"),
                        "span_end": datasets.Value("int32"),
                        "span_text": datasets.Value("string"),
                        "input_text": datasets.Value("string"),
                        "turn_id": datasets.Value("int32"),
                    }
                ),
                "additional_answers": {
                    "0": datasets.features.Sequence(
                        {
                            "span_start": datasets.Value("int32"),
                            "span_end": datasets.Value("int32"),
                            "span_text": datasets.Value("string"),
                            "input_text": datasets.Value("string"),
                            "turn_id": datasets.Value("int32"),
                        }
                    ),
                    "1": datasets.features.Sequence(
                        ...  # identical fields to "0"
                    ),
                    "2": datasets.features.Sequence(
                        ...  # identical fields to "0"
                    ),
                },
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,

@@ -175,10 +187,7 @@ class Coqa(datasets.GeneratorBasedBuilder):
            source = row["source"]
            story = row["story"]
            questions = [
                {"input_text": q["input_text"], "turn_id": q["turn_id"]}
                for q in row["questions"]
            ]
            answers = [

@@ -187,7 +196,7 @@ class Coqa(datasets.GeneratorBasedBuilder):
                    "span_end": a["span_end"],
                    "span_text": a["span_text"],
                    "input_text": a["input_text"],
                    "turn_id": a["turn_id"],
                }
                for a in row["answers"]
            ]

(hunks @@ -201,7, @@ -211,7, and @@ -221,7 add the same trailing comma after "turn_id" in the comprehensions over row["additional_answers"]["0"], ["1"], and ["2"])

@@ -232,5 +241,5 @@ class Coqa(datasets.GeneratorBasedBuilder):
                "source": source,
                "questions": questions,
                "answers": answers,
                "additional_answers": additional_answers,
            }
lm_eval/datasets/coqa/dataset_infos.json  (+1 -1; diff not expanded)
lm_eval/datasets/drop/dataset_infos.json  (+1 -1; diff not expanded)
lm_eval/datasets/drop/drop.py  (+56 -45)

@@ -50,7 +50,8 @@ _URLS = {
    "drop": "https://s3-us-west-2.amazonaws.com/allennlp/datasets/drop/drop_dataset.zip",
}

_EMPTY_VALIDATED_ANSWER = [
    {
        "number": "",
        "date": {
            "day": "",

@@ -59,8 +60,9 @@ _EMPTY_VALIDATED_ANSWER = [{
        },
        "spans": [],
        "worker_id": "",
        "hit_id": "",
    }
]


class Drop(datasets.GeneratorBasedBuilder):

@@ -69,12 +71,14 @@ class Drop(datasets.GeneratorBasedBuilder):
    VERSION = datasets.Version("0.0.1")

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(
            name="drop", version=VERSION, description="The DROP dataset."
        ),
    ]

    def _info(self):
        features = datasets.Features(
            {
                "section_id": datasets.Value("string"),
                "passage": datasets.Value("string"),
                "question": datasets.Value("string"),

@@ -90,7 +94,8 @@ class Drop(datasets.GeneratorBasedBuilder):
                    "worker_id": datasets.Value("string"),
                    "hit_id": datasets.Value("string"),
                },
                "validated_answers": datasets.features.Sequence(
                    {
                        "number": datasets.Value("string"),
                        "date": {
                            "day": datasets.Value("string"),

@@ -100,8 +105,10 @@ class Drop(datasets.GeneratorBasedBuilder):
                        "spans": datasets.features.Sequence(datasets.Value("string")),
                        "worker_id": datasets.Value("string"),
                        "hit_id": datasets.Value("string"),
                    }
                ),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,

@@ -118,7 +125,9 @@ class Drop(datasets.GeneratorBasedBuilder):
                name=datasets.Split.TRAIN,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "filepath": os.path.join(
                        data_dir, "drop_dataset", "drop_dataset_train.json"
                    ),
                    "split": "train",
                },
            ),

@@ -126,7 +135,9 @@ class Drop(datasets.GeneratorBasedBuilder):
                name=datasets.Split.VALIDATION,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "filepath": os.path.join(
                        data_dir, "drop_dataset", "drop_dataset_dev.json"
                    ),
                    "split": "validation",
                },
            ),
lm_eval/datasets/gsm8k/dataset_infos.json  (+1 -1; diff not expanded)
lm_eval/datasets/gsm8k/gsm8k.py  (+9 -9)

@@ -56,8 +56,11 @@ class GSM8K(datasets.GeneratorBasedBuilder):
    VERSION = datasets.Version("0.0.1")

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(
            name="gsm8k",
            version=VERSION,
            description="The Grade School Math 8k dataset.",
        ),
    ]

    def _info(self):

@@ -90,10 +93,7 @@ class GSM8K(datasets.GeneratorBasedBuilder):
        datasets.SplitGenerator(
            name=datasets.Split.TEST,
            # These kwargs will be passed to _generate_examples
            gen_kwargs={"filepath": data_dir["test"], "split": "test"},
        ),
    ]