gaoqiong / lm-evaluation-harness / Commits

Commit 5fbc3f86
authored Jul 14, 2023 by haileyschoelkopf

fix random seed issue, log_samples optional

parent 09a71562
Showing 3 changed files with 55 additions and 43 deletions (+55 -43):

  lm_eval/api/task.py   +12 -15
  lm_eval/evaluator.py  +26 -15
  main.py               +17 -13
lm_eval/api/task.py (view file @ 5fbc3f86)
@@ -8,6 +8,7 @@ import evaluate
 import random
 import itertools
 import functools
 from tqdm import tqdm
 import datasets
 import numpy as np

@@ -217,8 +218,8 @@ class Task(abc.ABC):
         self._filters.append(filter_pipeline)

         self.sampler = samplers.Sampler(
-            list(self.fewshot_docs()), self, rnd=random.Random()
+            # TODO: pass the correct docs in here
+            list(self.fewshot_docs()), self, rnd=random.Random(1234)
         )

     def download(self, data_dir=None, cache_dir=None, download_mode=None):
         """Downloads and returns the task dataset.
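The change above swaps an unseeded random.Random() for random.Random(1234) when the few-shot sampler is built. A minimal, self-contained sketch of the effect (using a toy stand-in, not the harness's actual samplers.Sampler): seeding the generator once at construction makes the selected few-shot examples identical across runs.

import random

class ToySampler:
    """Toy stand-in for a few-shot sampler; not the harness's samplers.Sampler."""

    def __init__(self, docs, rnd):
        self.docs = docs
        self.rnd = rnd

    def sample(self, n):
        # Draw n few-shot examples without replacement.
        return self.rnd.sample(self.docs, n)

docs = [f"doc_{i}" for i in range(10)]

# Unseeded: a fresh Random() gives a different draw on every process start.
unseeded = ToySampler(docs, rnd=random.Random())

# Seeded as in this commit: the draw is reproducible across runs.
seeded_a = ToySampler(docs, rnd=random.Random(1234))
seeded_b = ToySampler(docs, rnd=random.Random(1234))
assert seeded_a.sample(3) == seeded_b.sample(3)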
@@ -366,13 +367,18 @@ class Task(abc.ABC):
             False
         ), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"

         eval_logger.info(f"Building contexts for task '{self._config.task}' on rank {rank}...")

         instances = []
         for doc_id, doc in utils.create_iterator(enumerate(docs), rank, world_size, limit):
             # sample fewshot context #TODO: need to offset doc_id by rank now!
             fewshot_ctx = self.fewshot_context(
-                doc, self._config.num_fewshot, rnd=random.Random()
+                doc,
+                self._config.num_fewshot,
             )

             # TODO: we should override self._config.repeats if doing greedy gen so users don't waste time+compute
@@ -453,7 +459,7 @@ class Task(abc.ABC):
         return len(re.split(r"\s+", doc))

     @utils.positional_deprecated
-    def fewshot_context(self, doc, num_fewshot, rnd=None):
+    def fewshot_context(self, doc, num_fewshot):
         """Returns a fewshot context string that is made up of a prepended description
         (if provided), the `num_fewshot` number of examples, and an appended prompt example.

@@ -461,15 +467,9 @@ class Task(abc.ABC):
             The document as returned from training_docs, validation_docs, or test_docs.
         :param num_fewshot: int
             The number of fewshot examples to provide in the returned context string.
-        :param rnd: random.Random
-            The pseudo-random number generator used to randomly sample examples.
-            WARNING: This is currently a required arg although it's optionalized with a default `None`.
         :returns: str
             The fewshot context.
         """
-        assert (
-            rnd is not None
-        ), "A `random.Random` generator argument must be provided to `rnd`"
         if num_fewshot == 0:
             # always prepend the (possibly empty) task description
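With the per-call rnd argument removed, callers now invoke task.fewshot_context(doc, num_fewshot) and rely on the task's own seeded sampler for example selection. A toy, self-contained illustration of that shape (a stub written for this note, not the harness's Task class):

import random

class StubTask:
    """Illustrative stub mirroring the new fewshot_context signature."""

    def __init__(self):
        self.rnd = random.Random(1234)  # seeded once, as in this commit
        self.examples = ["ex_a", "ex_b", "ex_c", "ex_d"]

    def fewshot_context(self, doc, num_fewshot):
        if num_fewshot == 0:
            return doc
        shots = self.rnd.sample(self.examples, num_fewshot)
        return "\n\n".join(shots + [doc])

print(StubTask().fewshot_context("What is 2 + 2?", 2))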
@@ -625,7 +625,7 @@ class ConfigurableTask(Task):
if
self
.
fewshot_docs
()
is
not
None
:
self
.
sampler
=
samplers
.
Sampler
(
list
(
self
.
fewshot_docs
()),
self
,
rnd
=
random
.
Random
()
list
(
self
.
fewshot_docs
()),
self
,
rnd
=
random
.
Random
(
1234
)
)
def
download
(
self
,
dataset_kwargs
=
None
):
...
...
@@ -1004,13 +1004,10 @@ class PerplexityTask(Task):
assert
k
==
0
return
[]
def
fewshot_context
(
self
,
doc
,
num_fewshot
,
rnd
=
None
):
def
fewshot_context
(
self
,
doc
,
num_fewshot
):
assert
(
num_fewshot
==
0
),
"The number of fewshot examples must be 0 for perplexity tasks."
assert
(
rnd
is
not
None
),
"A `random.Random` generator argument must be provided to `rnd`."
return
""
...
...
lm_eval/evaluator.py (view file @ 5fbc3f86)
@@ -45,6 +45,7 @@ def simple_evaluate(
     check_integrity=False,
     decontamination_ngrams_path=None,
     write_out=False,
+    log_samples=True,
 ):
     """Instantiate and evaluate a model on a list of tasks.
@@ -72,12 +73,17 @@ def simple_evaluate(
     :param check_integrity: bool
         Whether to run the relevant part of the test suite for the tasks
     :param write_out: bool
-        If True, write details about prompts and logits to json for all tasks
+        If True, write out an example document and model input for checking task integrity
+    :param log_samples: bool
+        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
     :return
         Dictionary of results
     """
-    random.seed(1234)
+    random.seed(0)
+    np.random.seed(1234)
+    torch.manual_seed(
+        1234
+    )  # TODO: this may affect training runs that are run with evaluation mid-run.

     assert tasks != [], "No tasks specified"
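A minimal sketch of what this reseeding buys (illustrative only; it assumes numpy and torch are installed, as the harness already requires): seeding Python's, NumPy's, and PyTorch's global generators up front makes any downstream sampling that relies on them repeatable across evaluation runs.

import random

import numpy as np
import torch

def reseed():
    # Mirrors the seeding added in simple_evaluate().
    random.seed(0)
    np.random.seed(1234)
    torch.manual_seed(1234)

reseed()
first = (random.random(), np.random.rand(), torch.rand(1).item())

reseed()
second = (random.random(), np.random.rand(), torch.rand(1).item())

assert first == second  # identical draws after reseeding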
@@ -118,6 +124,7 @@ def simple_evaluate(
bootstrap_iters
=
bootstrap_iters
,
decontamination_ngrams_path
=
decontamination_ngrams_path
,
write_out
=
write_out
,
log_samples
=
log_samples
,
)
if
lm
.
rank
==
0
:
...
...
@@ -154,6 +161,7 @@ def evaluate(
bootstrap_iters
=
100000
,
decontamination_ngrams_path
=
None
,
write_out
=
False
,
log_samples
=
True
,
):
"""Instantiate and evaluate a model on a list of tasks.
...
...
@@ -168,7 +176,9 @@ def evaluate(
     :param bootstrap_iters:
         Number of iterations for bootstrap statistics
     :param write_out: bool
-        If True, write all prompts, logits and metrics to json for offline analysis
+        If True, write out an example document and model input for checking task integrity
+    :param log_samples: bool
+        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
     :return
         Dictionary of results
     """
@@ -282,17 +292,18 @@ def evaluate(
             metrics = task.process_results(
                 doc, [req.filtered_resps[key] for req in requests]
             )
-            target = task.doc_to_target(doc)
-            example = {
-                "doc_id": doc_id,
-                "doc": doc,
-                "target": target,
-                "arguments": [req.args for req in requests],
-                "resps": [req.resps for req in requests],
-                "filtered_resps": [req.filtered_resps[key] for req in requests],
-            }
-            example.update(metrics)
-            samples[task_name].append(example)
+            if log_samples:
+                target = task.doc_to_target(doc)
+                example = {
+                    "doc_id": doc_id,
+                    "doc": doc,
+                    "target": target,
+                    "arguments": [req.args for req in requests],
+                    "resps": [req.resps for req in requests],
+                    "filtered_resps": [req.filtered_resps[key] for req in requests],
+                }
+                example.update(metrics)
+                samples[task_name].append(example)
             for metric, value in metrics.items():
                 vals[(task_name, key, metric)].append(value)
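For reference, the per-document record gathered under log_samples has roughly the shape below (field names taken from the diff; the values are made up for illustration). Each record is later written as one line of a task's .jsonl file.

import json

# Hypothetical example record; real values come from the task and the model.
example = {
    "doc_id": 0,
    "doc": {"question": "What is 2 + 2?", "answer": "4"},
    "target": "4",
    "arguments": [("What is 2 + 2?", " 4")],
    "resps": [[(-0.12, True)]],
    "filtered_resps": [(-0.12, True)],
}
example.update({"acc": 1.0})  # per-sample metrics are merged into the record
print(json.dumps(example))    # one line of the per-task .jsonl output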
@@ -366,7 +377,7 @@ def evaluate(
"results"
:
dict
(
results
),
"configs"
:
dict
(
configs
),
"versions"
:
dict
(
versions
),
"samples"
:
samples
,
"samples"
:
samples
if
log_samples
else
{}
,
}
else
:
...
...
main.py (view file @ 5fbc3f86)
@@ -43,6 +43,7 @@ def parse_args():
     parser.add_argument("--decontamination_ngrams_path", default=None)
     parser.add_argument("--check_integrity", action="store_true")
     parser.add_argument("--write_out", action="store_true", default=False)
+    parser.add_argument("--log_samples", action="store_true", default=True)
     return parser.parse_args()
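A side note on the flag added above: with action="store_true" combined with default=True, the parsed value is True whether or not --log_samples is supplied on the command line. A small sketch of that argparse behavior:

import argparse

parser = argparse.ArgumentParser()
# Mirrors the flag added in this commit.
parser.add_argument("--log_samples", action="store_true", default=True)

print(parser.parse_args([]).log_samples)                  # True (default)
print(parser.parse_args(["--log_samples"]).log_samples)   # True (flag given)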
@@ -89,10 +90,12 @@ def main():
decontamination_ngrams_path
=
args
.
decontamination_ngrams_path
,
check_integrity
=
args
.
check_integrity
,
write_out
=
args
.
write_out
,
log_samples
=
args
.
log_samples
,
)
if
results
is
not
None
:
samples
=
results
.
pop
(
"samples"
)
if
args
.
log_samples
:
samples
=
results
.
pop
(
"samples"
)
dumped
=
json
.
dumps
(
results
,
indent
=
2
,
default
=
lambda
o
:
str
(
o
))
print
(
dumped
)
...
...
@@ -104,19 +107,20 @@ def main():
with
open
(
args
.
output_path
,
"w"
)
as
f
:
f
.
write
(
dumped
)
for
task_name
,
config
in
results
[
"configs"
].
items
():
output_name
=
"{}_{}"
.
format
(
re
.
sub
(
"/"
,
"__"
,
args
.
model_args
),
task_name
)
if
os
.
path
.
isdir
(
args
.
output_path
):
filename
=
f
"./
{
args
.
output_path
}
/
{
output_name
}
.jsonl"
elif
os
.
path
.
isfile
(
args
.
output_path
):
filename
=
(
f
"./
{
os
.
path
.
dirname
(
args
.
output_path
)
}
/
{
output_name
}
.jsonl"
if
args
.
log_samples
:
for
task_name
,
config
in
results
[
"configs"
].
items
():
output_name
=
"{}_{}"
.
format
(
re
.
sub
(
"/"
,
"__"
,
args
.
model_args
),
task_name
)
with
jsonlines
.
open
(
filename
,
"w"
)
as
f
:
f
.
write_all
(
samples
[
task_name
])
if
os
.
path
.
isdir
(
args
.
output_path
):
filename
=
f
"./
{
args
.
output_path
}
/
{
output_name
}
.jsonl"
elif
os
.
path
.
isfile
(
args
.
output_path
):
filename
=
(
f
"./
{
os
.
path
.
dirname
(
args
.
output_path
)
}
/
{
output_name
}
.jsonl"
)
with
jsonlines
.
open
(
filename
,
"w"
)
as
f
:
f
.
write_all
(
samples
[
task_name
])
print
(
f
"
{
args
.
model
}
(
{
args
.
model_args
}
), limit:
{
args
.
limit
}
, num_fewshot:
{
args
.
num_fewshot
}
, "
...
...
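Once a run has written its per-task sample files, they can be read back for post-hoc analysis with the same jsonlines library used above. A brief sketch (the path is hypothetical; it follows the {model_args}_{task_name}.jsonl pattern from the code):

import jsonlines

# Hypothetical path produced by a run with --log_samples and --output_path set.
path = "./outputs/pretrained=gpt2_lambada.jsonl"

with jsonlines.open(path) as reader:
    for example in reader:
        # Each line holds one logged sample: doc_id, doc, target,
        # arguments, resps, filtered_resps, plus per-sample metrics.
        print(example["doc_id"], example["target"])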