Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
5fbc3f86
Commit
5fbc3f86
authored
Jul 14, 2023
by
haileyschoelkopf
Browse files
fix random seed issue, log_samples optional
parent
09a71562
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
55 additions
and
43 deletions
+55
-43
lm_eval/api/task.py
lm_eval/api/task.py
+12
-15
lm_eval/evaluator.py
lm_eval/evaluator.py
+26
-15
main.py
main.py
+17
-13
No files found.
lm_eval/api/task.py
View file @
5fbc3f86
...
@@ -8,6 +8,7 @@ import evaluate
...
@@ -8,6 +8,7 @@ import evaluate
import
random
import
random
import
itertools
import
itertools
import
functools
import
functools
from
tqdm
import
tqdm
import
datasets
import
datasets
import
numpy
as
np
import
numpy
as
np
...
@@ -217,8 +218,8 @@ class Task(abc.ABC):
...
@@ -217,8 +218,8 @@ class Task(abc.ABC):
self
.
_filters
.
append
(
filter_pipeline
)
self
.
_filters
.
append
(
filter_pipeline
)
self
.
sampler
=
samplers
.
Sampler
(
self
.
sampler
=
samplers
.
Sampler
(
list
(
self
.
fewshot_docs
()),
self
,
rnd
=
random
.
Random
()
list
(
self
.
fewshot_docs
()),
self
,
rnd
=
random
.
Random
(
1234
)
)
# TODO: pass the correct docs in here
)
def
download
(
self
,
data_dir
=
None
,
cache_dir
=
None
,
download_mode
=
None
):
def
download
(
self
,
data_dir
=
None
,
cache_dir
=
None
,
download_mode
=
None
):
"""Downloads and returns the task dataset.
"""Downloads and returns the task dataset.
...
@@ -366,13 +367,18 @@ class Task(abc.ABC):
...
@@ -366,13 +367,18 @@ class Task(abc.ABC):
False
False
),
f
"Task dataset (path=
{
self
.
DATASET_PATH
}
, name=
{
self
.
DATASET_NAME
}
) must have valid or test docs!"
),
f
"Task dataset (path=
{
self
.
DATASET_PATH
}
, name=
{
self
.
DATASET_NAME
}
) must have valid or test docs!"
eval_logger
.
info
(
f
"Building contexts for task '
{
self
.
_config
.
task
}
' on rank
{
rank
}
..."
)
instances
=
[]
instances
=
[]
for
doc_id
,
doc
in
utils
.
create_iterator
(
for
doc_id
,
doc
in
utils
.
create_iterator
(
enumerate
(
docs
),
rank
,
world_size
,
limit
enumerate
(
docs
),
rank
,
world_size
,
limit
):
):
# sample fewshot context #TODO: need to offset doc_id by rank now!
# sample fewshot context #TODO: need to offset doc_id by rank now!
fewshot_ctx
=
self
.
fewshot_context
(
fewshot_ctx
=
self
.
fewshot_context
(
doc
,
self
.
_config
.
num_fewshot
,
rnd
=
random
.
Random
()
doc
,
self
.
_config
.
num_fewshot
,
)
)
# TODO: we should override self._config.repeats if doing greedy gen so users don't waste time+compute
# TODO: we should override self._config.repeats if doing greedy gen so users don't waste time+compute
...
@@ -453,7 +459,7 @@ class Task(abc.ABC):
...
@@ -453,7 +459,7 @@ class Task(abc.ABC):
return
len
(
re
.
split
(
r
"\s+"
,
doc
))
return
len
(
re
.
split
(
r
"\s+"
,
doc
))
@
utils
.
positional_deprecated
@
utils
.
positional_deprecated
def
fewshot_context
(
self
,
doc
,
num_fewshot
,
rnd
=
None
):
def
fewshot_context
(
self
,
doc
,
num_fewshot
):
"""Returns a fewshot context string that is made up of a prepended description
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
...
@@ -461,15 +467,9 @@ class Task(abc.ABC):
...
@@ -461,15 +467,9 @@ class Task(abc.ABC):
The document as returned from training_docs, validation_docs, or test_docs.
The document as returned from training_docs, validation_docs, or test_docs.
:param num_fewshot: int
:param num_fewshot: int
The number of fewshot examples to provide in the returned context string.
The number of fewshot examples to provide in the returned context string.
:param rnd: random.Random
The pseudo-random number generator used to randomly sample examples.
WARNING: This is currently a required arg although it's optionalized with a default `None`.
:returns: str
:returns: str
The fewshot context.
The fewshot context.
"""
"""
assert
(
rnd
is
not
None
),
"A `random.Random` generator argument must be provided to `rnd`"
if
num_fewshot
==
0
:
if
num_fewshot
==
0
:
# always prepend the (possibly empty) task description
# always prepend the (possibly empty) task description
...
@@ -625,7 +625,7 @@ class ConfigurableTask(Task):
...
@@ -625,7 +625,7 @@ class ConfigurableTask(Task):
if
self
.
fewshot_docs
()
is
not
None
:
if
self
.
fewshot_docs
()
is
not
None
:
self
.
sampler
=
samplers
.
Sampler
(
self
.
sampler
=
samplers
.
Sampler
(
list
(
self
.
fewshot_docs
()),
self
,
rnd
=
random
.
Random
()
list
(
self
.
fewshot_docs
()),
self
,
rnd
=
random
.
Random
(
1234
)
)
)
def
download
(
self
,
dataset_kwargs
=
None
):
def
download
(
self
,
dataset_kwargs
=
None
):
...
@@ -1004,13 +1004,10 @@ class PerplexityTask(Task):
...
@@ -1004,13 +1004,10 @@ class PerplexityTask(Task):
assert
k
==
0
assert
k
==
0
return
[]
return
[]
def
fewshot_context
(
self
,
doc
,
num_fewshot
,
rnd
=
None
):
def
fewshot_context
(
self
,
doc
,
num_fewshot
):
assert
(
assert
(
num_fewshot
==
0
num_fewshot
==
0
),
"The number of fewshot examples must be 0 for perplexity tasks."
),
"The number of fewshot examples must be 0 for perplexity tasks."
assert
(
rnd
is
not
None
),
"A `random.Random` generator argument must be provided to `rnd`."
return
""
return
""
...
...
lm_eval/evaluator.py
View file @
5fbc3f86
...
@@ -45,6 +45,7 @@ def simple_evaluate(
...
@@ -45,6 +45,7 @@ def simple_evaluate(
check_integrity
=
False
,
check_integrity
=
False
,
decontamination_ngrams_path
=
None
,
decontamination_ngrams_path
=
None
,
write_out
=
False
,
write_out
=
False
,
log_samples
=
True
,
):
):
"""Instantiate and evaluate a model on a list of tasks.
"""Instantiate and evaluate a model on a list of tasks.
...
@@ -72,12 +73,17 @@ def simple_evaluate(
...
@@ -72,12 +73,17 @@ def simple_evaluate(
:param check_integrity: bool
:param check_integrity: bool
Whether to run the relevant part of the test suite for the tasks
Whether to run the relevant part of the test suite for the tasks
:param write_out: bool
:param write_out: bool
If True, write details about prompts and logits to json for all tasks
If True, write out an example document and model input for checking task integrity
:param log_samples: bool
If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
:return
:return
Dictionary of results
Dictionary of results
"""
"""
random
.
seed
(
1234
)
random
.
seed
(
0
)
np
.
random
.
seed
(
1234
)
np
.
random
.
seed
(
1234
)
torch
.
manual_seed
(
1234
)
# TODO: this may affect training runs that are run with evaluation mid-run.
assert
tasks
!=
[],
"No tasks specified"
assert
tasks
!=
[],
"No tasks specified"
...
@@ -118,6 +124,7 @@ def simple_evaluate(
...
@@ -118,6 +124,7 @@ def simple_evaluate(
bootstrap_iters
=
bootstrap_iters
,
bootstrap_iters
=
bootstrap_iters
,
decontamination_ngrams_path
=
decontamination_ngrams_path
,
decontamination_ngrams_path
=
decontamination_ngrams_path
,
write_out
=
write_out
,
write_out
=
write_out
,
log_samples
=
log_samples
,
)
)
if
lm
.
rank
==
0
:
if
lm
.
rank
==
0
:
...
@@ -154,6 +161,7 @@ def evaluate(
...
@@ -154,6 +161,7 @@ def evaluate(
bootstrap_iters
=
100000
,
bootstrap_iters
=
100000
,
decontamination_ngrams_path
=
None
,
decontamination_ngrams_path
=
None
,
write_out
=
False
,
write_out
=
False
,
log_samples
=
True
,
):
):
"""Instantiate and evaluate a model on a list of tasks.
"""Instantiate and evaluate a model on a list of tasks.
...
@@ -168,7 +176,9 @@ def evaluate(
...
@@ -168,7 +176,9 @@ def evaluate(
:param bootstrap_iters:
:param bootstrap_iters:
Number of iterations for bootstrap statistics
Number of iterations for bootstrap statistics
:param write_out: bool
:param write_out: bool
If True, write all prompts, logits and metrics to json for offline analysis
If True, write out an example document and model input for checking task integrity
:param log_samples: bool
If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
:return
:return
Dictionary of results
Dictionary of results
"""
"""
...
@@ -282,17 +292,18 @@ def evaluate(
...
@@ -282,17 +292,18 @@ def evaluate(
metrics
=
task
.
process_results
(
metrics
=
task
.
process_results
(
doc
,
[
req
.
filtered_resps
[
key
]
for
req
in
requests
]
doc
,
[
req
.
filtered_resps
[
key
]
for
req
in
requests
]
)
)
target
=
task
.
doc_to_target
(
doc
)
if
log_samples
:
example
=
{
target
=
task
.
doc_to_target
(
doc
)
"doc_id"
:
doc_id
,
example
=
{
"doc"
:
doc
,
"doc_id"
:
doc_id
,
"target"
:
target
,
"doc"
:
doc
,
"arguments"
:
[
req
.
args
for
req
in
requests
],
"target"
:
target
,
"resps"
:
[
req
.
resps
for
req
in
requests
],
"arguments"
:
[
req
.
args
for
req
in
requests
],
"filtered_resps"
:
[
req
.
filtered_resps
[
key
]
for
req
in
requests
],
"resps"
:
[
req
.
resps
for
req
in
requests
],
}
"filtered_resps"
:
[
req
.
filtered_resps
[
key
]
for
req
in
requests
],
example
.
update
(
metrics
)
}
samples
[
task_name
].
append
(
example
)
example
.
update
(
metrics
)
samples
[
task_name
].
append
(
example
)
for
metric
,
value
in
metrics
.
items
():
for
metric
,
value
in
metrics
.
items
():
vals
[(
task_name
,
key
,
metric
)].
append
(
value
)
vals
[(
task_name
,
key
,
metric
)].
append
(
value
)
...
@@ -366,7 +377,7 @@ def evaluate(
...
@@ -366,7 +377,7 @@ def evaluate(
"results"
:
dict
(
results
),
"results"
:
dict
(
results
),
"configs"
:
dict
(
configs
),
"configs"
:
dict
(
configs
),
"versions"
:
dict
(
versions
),
"versions"
:
dict
(
versions
),
"samples"
:
samples
,
"samples"
:
samples
if
log_samples
else
{}
,
}
}
else
:
else
:
...
...
main.py
View file @
5fbc3f86
...
@@ -43,6 +43,7 @@ def parse_args():
...
@@ -43,6 +43,7 @@ def parse_args():
parser
.
add_argument
(
"--decontamination_ngrams_path"
,
default
=
None
)
parser
.
add_argument
(
"--decontamination_ngrams_path"
,
default
=
None
)
parser
.
add_argument
(
"--check_integrity"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--check_integrity"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--write_out"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--write_out"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--log_samples"
,
action
=
"store_true"
,
default
=
True
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -89,10 +90,12 @@ def main():
...
@@ -89,10 +90,12 @@ def main():
decontamination_ngrams_path
=
args
.
decontamination_ngrams_path
,
decontamination_ngrams_path
=
args
.
decontamination_ngrams_path
,
check_integrity
=
args
.
check_integrity
,
check_integrity
=
args
.
check_integrity
,
write_out
=
args
.
write_out
,
write_out
=
args
.
write_out
,
log_samples
=
args
.
log_samples
,
)
)
if
results
is
not
None
:
if
results
is
not
None
:
samples
=
results
.
pop
(
"samples"
)
if
args
.
log_samples
:
samples
=
results
.
pop
(
"samples"
)
dumped
=
json
.
dumps
(
results
,
indent
=
2
,
default
=
lambda
o
:
str
(
o
))
dumped
=
json
.
dumps
(
results
,
indent
=
2
,
default
=
lambda
o
:
str
(
o
))
print
(
dumped
)
print
(
dumped
)
...
@@ -104,19 +107,20 @@ def main():
...
@@ -104,19 +107,20 @@ def main():
with
open
(
args
.
output_path
,
"w"
)
as
f
:
with
open
(
args
.
output_path
,
"w"
)
as
f
:
f
.
write
(
dumped
)
f
.
write
(
dumped
)
for
task_name
,
config
in
results
[
"configs"
].
items
():
if
args
.
log_samples
:
output_name
=
"{}_{}"
.
format
(
for
task_name
,
config
in
results
[
"configs"
].
items
():
re
.
sub
(
"/"
,
"__"
,
args
.
model_args
),
task_name
output_name
=
"{}_{}"
.
format
(
)
re
.
sub
(
"/"
,
"__"
,
args
.
model_args
),
task_name
if
os
.
path
.
isdir
(
args
.
output_path
):
filename
=
f
"./
{
args
.
output_path
}
/
{
output_name
}
.jsonl"
elif
os
.
path
.
isfile
(
args
.
output_path
):
filename
=
(
f
"./
{
os
.
path
.
dirname
(
args
.
output_path
)
}
/
{
output_name
}
.jsonl"
)
)
if
os
.
path
.
isdir
(
args
.
output_path
):
with
jsonlines
.
open
(
filename
,
"w"
)
as
f
:
filename
=
f
"./
{
args
.
output_path
}
/
{
output_name
}
.jsonl"
f
.
write_all
(
samples
[
task_name
])
elif
os
.
path
.
isfile
(
args
.
output_path
):
filename
=
(
f
"./
{
os
.
path
.
dirname
(
args
.
output_path
)
}
/
{
output_name
}
.jsonl"
)
with
jsonlines
.
open
(
filename
,
"w"
)
as
f
:
f
.
write_all
(
samples
[
task_name
])
print
(
print
(
f
"
{
args
.
model
}
(
{
args
.
model_args
}
), limit:
{
args
.
limit
}
, num_fewshot:
{
args
.
num_fewshot
}
, "
f
"
{
args
.
model
}
(
{
args
.
model_args
}
), limit:
{
args
.
limit
}
, num_fewshot:
{
args
.
num_fewshot
}
, "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment