gaoqiong / lm-evaluation-harness

Commit 6ba2a2b0 (unverified), authored Sep 14, 2023 by Lintang Sutawika, committed by GitHub on Sep 14, 2023

    Merge pull request #850 from EleutherAI/better-benchmark

Parents: b65b9ca3, b155946e
Showing 7 changed files with 140 additions and 52 deletions:
    lm_eval/benchmarks/__init__.py   +1    -1
    lm_eval/benchmarks/pythia.yaml   +4    -4
    lm_eval/evaluator.py             +106  -37
    lm_eval/prompts/__init__.py      +9    -1
    lm_eval/tasks/__init__.py        +15   -4
    lm_eval/utils.py                 +3    -3
    main.py                          +2    -2
lm_eval/benchmarks/__init__.py

...
@@ -44,7 +44,7 @@ def include_benchmarks(task_dir: str) -> None:
     task_names = utils.pattern_match(task_list, ALL_TASKS)
     for task in task_names:
-        if task in TASK_REGISTRY:
+        if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
             if group in GROUP_REGISTRY:
                 GROUP_REGISTRY[group].append(task)
             else:
...
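The widened membership test is the substantive change here: a benchmark's task list may now name another group, not only registered leaf tasks. A minimal self-contained sketch of the registration step, with toy registries and an fnmatch-based matcher standing in for the real lm_eval.utils.pattern_match:

    import fnmatch

    # Toy stand-ins for the real registries in lm_eval.
    TASK_REGISTRY = {"sciq": ..., "arc_easy": ..., "arc_challenge": ...}
    GROUP_REGISTRY = {"ai2_arc": ["arc_easy", "arc_challenge"]}
    ALL_TASKS = sorted(set(TASK_REGISTRY) | set(GROUP_REGISTRY))

    def pattern_match(patterns, source_list):
        # Expand wildcard patterns such as "hendrycksTest*" against known names.
        names = set()
        for pattern in patterns:
            names.update(fnmatch.filter(source_list, pattern))
        return sorted(names)

    group = "pythia"
    task_list = ["sciq", "ai2_arc"]  # "ai2_arc" is a group, not a leaf task
    for task in pattern_match(task_list, ALL_TASKS):
        if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
            GROUP_REGISTRY.setdefault(group, []).append(task)

    print(GROUP_REGISTRY["pythia"])  # ['ai2_arc', 'sciq']

Before this commit, the "ai2_arc" entry would have been skipped because it only exists in GROUP_REGISTRY.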
lm_eval/benchmarks/pythia.yaml

 group: pythia
 task:
   - lambada_openai
-  - wikitext
+  - logiqa
   - piqa
   - sciq
-  - wsc
+  - wikitext
   - winogrande
-  - arc
-  - logiqa
+  - wsc
+  - ai2_arc
   - blimp
   - hendrycksTest*
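With groups registrable as benchmark members, the group name itself is what a caller passes as a task. A hedged sketch of invoking this benchmark through simple_evaluate (which this commit also touches); the model and model_args values are illustrative, not taken from the commit:

    from lm_eval import evaluator

    # "pythia" resolves through GROUP_REGISTRY to the task list in the YAML
    # above, including the nested ai2_arc group and the hendrycksTest* pattern.
    results = evaluator.simple_evaluate(
        model="hf",
        model_args="pretrained=EleutherAI/pythia-160m",  # illustrative
        tasks=["pythia"],
    )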
lm_eval/evaluator.py

...
@@ -120,6 +120,8 @@ def simple_evaluate(
         task_obj = task_dict[task_name]
         if type(task_obj) == tuple:
             group, task_obj = task_obj
+            if task_obj is None:
+                continue
 
         config = task_obj._config
         if num_fewshot is not None:
...
@@ -209,23 +211,30 @@ def evaluate(
     samples = collections.defaultdict(list)
     # tracks all Instances/requests a model must generate output on.
     requests = collections.defaultdict(list)
-    # Stores task scores based on task grouping.
-    aggregate = collections.defaultdict(dict)
-    # tracks if a task was chosen via user selecting a group containing it
-    task_groups = collections.defaultdict(dict)
+    # Aggregated task scores presented with groups
+    results_agg = collections.defaultdict(dict)
+    # Aggregated groups scores only
+    groups_agg = collections.defaultdict(dict)
     # stores the amount to pad out reqs per req. type so that
     # number of fwd passes per distributed rank is equal
     padding_requests = collections.defaultdict(int)
+    # store the hierarchy to do proper ordering
+    task_hierarchy = collections.defaultdict(list)
+    # Stores group related keys and values for group-aggregation
+    task_groups = collections.defaultdict(dict)
+    # store the ordering of tasks and groups
+    task_order = collections.defaultdict(int)
+    # store the aggregation for aggregating across tasks in the same group
+    sample_agg_fn = collections.defaultdict(dict)
     # get lists of each type of request
     for task_name, task in task_dict.items():
         if type(task) == tuple:
-            group, task = task
-            task_groups[task_name] = group
-            aggregate[task_name] = {}
+            group_name, task = task
+            task_hierarchy[group_name].append(task_name)
+        else:
+            task_hierarchy[task_name] = []
 
+        if task is None:
+            continue
+
         versions[task_name] = task.VERSION
         configs[task_name] = dict(task.dump_config())
...
@@ -301,6 +310,8 @@ def evaluate(
     for task_name, task in task_dict.items():
         if type(task) == tuple:
             group, task = task
+            if task is None:
+                continue
 
         task.apply_filters()
 
     ### Collect values of metrics on all datapoints ###
...
@@ -310,6 +321,8 @@ def evaluate(
     for task_name, task in task_dict.items():
         if type(task) == tuple:
             group, task = task
+            if task is None:
+                continue
 
         # TODO: make it possible to use a different metric per filter
         # iterate over different filters used
         for key in task.instances[0].filtered_resps.keys():
...
@@ -396,27 +409,64 @@ def evaluate(
         vals = vals_torch
 
     if lm.rank == 0:
+        ### Get task ordering for correct sample-wide aggregation
+        group_to_task = {}
+        for group in task_hierarchy.keys():
+            if group not in task_order:
+                task_order[group] = 0
+
+            if len(task_hierarchy[group]) > 0:
+                group_to_task[group] = task_hierarchy[group].copy()
+
+            for task in task_hierarchy[group]:
+                if task in task_order:
+                    task_order[task] += 1
+                else:
+                    task_order[task] = 1 + task_order[group]
+
+                if task in task_hierarchy:
+                    group_to_task[group].remove(task)
+                    group_to_task[group].extend(task_hierarchy[task])
+
+        task_to_group = {}
+        for group in group_to_task:
+            for task in group_to_task[group]:
+                if task in task_to_group:
+                    task_to_group[task].append(group)
+                else:
+                    task_to_group[task] = [group]
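To make the new bookkeeping concrete, here is the same ordering logic run standalone on a toy hierarchy (names invented for illustration). Each task's order is its nesting depth, and task_to_group maps each task to every group whose sample-level aggregate it should feed:

    import collections

    # Toy hierarchy: "pythia" contains a leaf task and the nested "ai2_arc" group.
    task_hierarchy = collections.defaultdict(
        list, {"pythia": ["sciq", "ai2_arc"], "ai2_arc": ["arc_easy", "arc_challenge"]}
    )
    task_order = collections.defaultdict(int)

    group_to_task = {}
    for group in task_hierarchy.keys():
        if group not in task_order:
            task_order[group] = 0
        if len(task_hierarchy[group]) > 0:
            group_to_task[group] = task_hierarchy[group].copy()
        for task in task_hierarchy[group]:
            if task in task_order:
                task_order[task] += 1
            else:
                task_order[task] = 1 + task_order[group]
            if task in task_hierarchy:
                # Nested group: replace it with its own members.
                group_to_task[group].remove(task)
                group_to_task[group].extend(task_hierarchy[task])

    task_to_group = {}
    for group in group_to_task:
        for task in group_to_task[group]:
            task_to_group.setdefault(task, []).append(group)

    print(dict(task_order))
    # {'pythia': 0, 'sciq': 1, 'ai2_arc': 1, 'arc_easy': 2, 'arc_challenge': 2}
    print(task_to_group)
    # {'sciq': ['pythia'], 'arc_easy': ['pythia', 'ai2_arc'],
    #  'arc_challenge': ['pythia', 'ai2_arc']}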
         ### Aggregate results over all datapoints ###
         # aggregate results ; run bootstrap CIs
         for (task_name, key, metric), items in vals.items():
             task = task_dict[task_name]
+            metric_key = metric + "," + key
+
             if type(task) == tuple:
-                group, task = task
+                group_name, task = task
+            else:
+                group_name = None
 
-            task_score = task.aggregation()[metric](items)
-            results[task_name][metric + "," + key] = task_score
-            # Need to put back in results
-            # pythia | acc
-            #        | perplexity
-            #        | word_perplexity
-            #        | byte_perplexity
-            #        | bits_per_byte
-            if task_name in task_groups:
-                group_name = task_groups[task_name]
-                if metric in list(aggregate[group_name].keys()):
-                    aggregate[group_name][metric].append(task_score)
-                else:
-                    aggregate[group_name][metric] = [task_score]
+            agg_fn = task.aggregation()[metric]
+            task_score = agg_fn(items)
+
+            if group_name is not None:
+                sample_metric_key = metric + "(sample agg)," + key
+                for grouping in task_to_group[task_name]:
+                    if metric_key in results[grouping]:
+                        results[grouping][metric_key].append(task_score)
+                    else:
+                        results[grouping][metric_key] = [task_score]
+
+                    if sample_metric_key in results[grouping]:
+                        results[grouping][sample_metric_key] += items
+                    else:
+                        results[grouping][sample_metric_key] = items.copy()
+                        sample_agg_fn[grouping][sample_metric_key] = agg_fn
+
+            results[task_name][metric_key] = task_score
 
             # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
             # so we run them less iterations. still looking for a cleaner way to do this
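The reason each group carries both metric_key and a parallel sample_metric_key is that a group score can be computed two ways: an unweighted average of per-task scores, or the metric's own aggregation function applied to the pooled per-sample values. A toy example of why the two disagree when member tasks differ in size (numbers invented):

    import numpy as np

    # Per-sample accuracy indicators for two tasks of very different sizes.
    task_a = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]  # 10 samples, acc = 0.8
    task_b = [0, 0]                          # 2 samples,  acc = 0.0

    agg_fn = np.mean  # stand-in for task.aggregation()[metric]

    per_task_scores = [agg_fn(task_a), agg_fn(task_b)]
    print(np.average(per_task_scores))  # 0.4    -- mean of task scores
    print(agg_fn(task_a + task_b))      # ~0.667 -- "(sample agg)" over pooled items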
...
@@ -431,19 +481,38 @@ def evaluate(
             if stderr is not None:
                 results[task_name][metric + "_stderr" + "," + key] = stderr(items)
 
-        if bool(aggregate):
-            for group in aggregate.keys():
-                for metric in aggregate[group].keys():
-                    aggregate[group][metric] = np.average(aggregate[group][metric])
-                versions[group] = "N/A"
+        if bool(results):
+            for task_or_group in results.keys():
+                for metric in results[task_or_group].keys():
+                    if type(results[task_or_group][metric]) == list:
+                        if "(sample agg)" in metric:
+                            results[task_or_group][metric] = sample_agg_fn[task_or_group][metric](
+                                results[task_or_group][metric]
+                            )
+                        else:
+                            results[task_or_group][metric] = np.average(
+                                results[task_or_group][metric]
+                            )
+                        versions[task_or_group] = "N/A"
+        for task_name, task in task_dict.items():
+            if type(task) == tuple:
+                group_name, task = task
+
+                order = task_order[group_name]
+                tabbed_name = "-" * order + group_name
+                results_agg[tabbed_name] = results[group_name]
+                versions[tabbed_name] = versions[group_name]
+
+                if order == 0:
+                    groups_agg[group_name] = results[group_name]
+
+            order = task_order[task_name]
+            tabbed_name = "-" * order + task_name
+            results_agg[tabbed_name] = results[task_name]
+            versions[tabbed_name] = versions[task_name]
+
         results_dict = {
-            "results": dict(sorted(results.items())),
-            **(
-                {"aggregate": dict(sorted(aggregate.items()))}
-                if bool(aggregate)
-                else {}
-            ),
+            "results": dict(results_agg.items()),
+            **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
             "configs": dict(sorted(configs.items())),
             "versions": dict(sorted(versions.items())),
         }
...
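The "-" * order prefixes are what indent nested tasks under their group when the table is printed, and "groups" now carries only depth-0 groups. An illustrative shape of the returned dictionary for the toy hierarchy above (values elided; keys invented for illustration):

    results_dict = {
        "results": {
            "pythia": {...},           # depth-0 group row
            "-sciq": {...},            # member task, one level deep
            "-ai2_arc": {...},         # nested group
            "--arc_easy": {...},       # task two levels deep
            "--arc_challenge": {...},
        },
        "groups": {"pythia": {...}},   # only order == 0 groups appear here
        "configs": {...},
        "versions": {...},
    }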
lm_eval/prompts/__init__.py
View file @
6ba2a2b0
import
ast
from
lm_eval
import
utils
from
lm_eval
import
utils
from
lm_eval.logger
import
eval_logger
from
lm_eval.logger
import
eval_logger
...
@@ -63,6 +65,12 @@ def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwa
...
@@ -63,6 +65,12 @@ def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwa
else
:
else
:
prompts
=
DatasetTemplates
(
dataset_name
=
dataset_name
,
subset_name
=
subset_name
)
prompts
=
DatasetTemplates
(
dataset_name
=
dataset_name
,
subset_name
=
subset_name
)
category_name
,
prompt_name
=
use_prompt
.
split
(
":"
)
category_name
,
*
prompt_name
=
use_prompt
.
split
(
":"
)
# TODO allow to multiple prompt naming
# if len(prompt_name) > 1:
# prompt_list = []
# for prompt in prompt_name:
# prompt_list.append(utils.pattern_match(prompt_name, prompts.all_template_names))
# else:
prompt_list
=
utils
.
pattern_match
(
prompt_name
,
prompts
.
all_template_names
)
prompt_list
=
utils
.
pattern_match
(
prompt_name
,
prompts
.
all_template_names
)
return
[
":"
.
join
([
category_name
,
prompt
])
for
prompt
in
prompt_list
]
return
[
":"
.
join
([
category_name
,
prompt
])
for
prompt
in
prompt_list
]
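The star-unpack is a small but real behavior change: prompt_name is now a list (possibly empty, or holding the remainder when the string contains further ":" separators), which is the shape pattern_match expects. A quick illustration (the prompt string is invented):

    # The old form raised ValueError unless there was exactly one ":".
    category_name, *prompt_name = "promptsource:GPT-3 Style".split(":")
    print(category_name)  # 'promptsource'
    print(prompt_name)    # ['GPT-3 Style'] -- a list, ready for pattern_match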
lm_eval/tasks/__init__.py

...
@@ -136,6 +136,9 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
     task_name_from_config_dict = {}
     task_name_from_object_dict = {}
 
+    if type(task_name_list) != list:
+        task_name_list = [task_name_list]
+
     for task_element in task_name_list:
         if isinstance(task_element, str):
...
@@ -143,12 +146,20 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
                 group_name = task_element
                 for task_name in GROUP_REGISTRY[task_element]:
                     if task_name not in task_name_from_registry_dict:
+                        task_obj = get_task_dict(task_name)
+                        if task_name in task_obj.keys():
+                            task_dict = {
+                                task_name: (group_name, task_obj[task_name]),
+                            }
+                        else:
+                            task_dict = {
+                                task_name: (group_name, None),
+                                **task_obj,
+                            }
+
                         task_name_from_registry_dict = {
                             **task_name_from_registry_dict,
-                            task_name: (
-                                group_name,
-                                get_task(task_name=task_name, config=config),
-                            ),
+                            **task_dict,
                         }
         else:
             task_name = task_element
...
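The recursive get_task_dict(task_name) call is what makes groups-of-groups resolvable: a member that is itself a group contributes a (group_name, None) placeholder plus its own flattened members. For the toy hierarchy used earlier, the merged registry dict would look roughly like this (Task objects elided):

    task_name_from_registry_dict = {
        "sciq": ("pythia", ...),            # ordinary member task
        "ai2_arc": ("pythia", None),        # nested group: placeholder entry
        "arc_easy": ("ai2_arc", ...),       # flattened from the recursive call
        "arc_challenge": ("ai2_arc", ...),
    }

These None placeholders are also why the evaluator.py hunks above add the "if task is None: continue" guards.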
lm_eval/utils.py

...
@@ -267,9 +267,9 @@ def make_table(result_dict, column: str = "results"):
     from pytablewriter import MarkdownTableWriter, LatexTableWriter
 
     if column == "results":
-        column_name = "Task"
+        column_name = "Tasks"
-    elif column == "aggregate":
-        column_name = "Benchmark"
+    elif column == "groups":
+        column_name = "Groups"
 
     md_writer = MarkdownTableWriter()
     latex_writer = LatexTableWriter()
...
main.py

...
@@ -209,8 +209,8 @@ def main() -> None:
         f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
     )
 
     print(evaluator.make_table(results))
-    if "aggregate" in results:
-        print(evaluator.make_table(results, "aggregate"))
+    if "groups" in results:
+        print(evaluator.make_table(results, "groups"))
 
 
 if __name__ == "__main__":
...