Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
7dec84a0
Commit
7dec84a0
authored
Jun 15, 2023
by
gk
Browse files
Merge branch 'big-refactor' of github.com:EleutherAI/lm-evaluation-harness into big-refactor-merge
parents
e495e3a0
0c53ff49
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
35 additions
and
19 deletions
+35
-19
lm_eval/api/registry.py
lm_eval/api/registry.py
+3
-1
lm_eval/api/task.py
lm_eval/api/task.py
+23
-13
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+4
-0
main.py
main.py
+5
-5
No files found.
lm_eval/api/registry.py
View file @
7dec84a0
...
@@ -31,7 +31,7 @@ def get_model(model_name):
...
@@ -31,7 +31,7 @@ def get_model(model_name):
TASK_REGISTRY
=
{}
TASK_REGISTRY
=
{}
GROUP_REGISTRY
=
{}
GROUP_REGISTRY
=
{}
ALL_TASKS
=
[]
ALL_TASKS
=
set
()
func2task_index
=
{}
func2task_index
=
{}
...
@@ -42,6 +42,7 @@ def register_task(name):
...
@@ -42,6 +42,7 @@ def register_task(name):
),
f
"task named '
{
name
}
' conflicts with existing registered task!"
),
f
"task named '
{
name
}
' conflicts with existing registered task!"
TASK_REGISTRY
[
name
]
=
fn
TASK_REGISTRY
[
name
]
=
fn
ALL_TASKS
.
add
(
name
)
func2task_index
[
fn
.
__name__
]
=
name
func2task_index
[
fn
.
__name__
]
=
name
return
fn
return
fn
...
@@ -55,6 +56,7 @@ def register_group(name):
...
@@ -55,6 +56,7 @@ def register_group(name):
GROUP_REGISTRY
[
name
].
append
(
func_name
)
GROUP_REGISTRY
[
name
].
append
(
func_name
)
else
:
else
:
GROUP_REGISTRY
[
name
]
=
[
func_name
]
GROUP_REGISTRY
[
name
]
=
[
func_name
]
ALL_TASKS
.
add
(
name
)
return
fn
return
fn
return
decorate
return
decorate
...
...
lm_eval/api/task.py
View file @
7dec84a0
...
@@ -98,7 +98,9 @@ class TaskConfig(dict):
...
@@ -98,7 +98,9 @@ class TaskConfig(dict):
self
.
gold_alias
=
self
.
template_aliases
+
self
.
doc_to_target
self
.
gold_alias
=
self
.
template_aliases
+
self
.
doc_to_target
if
self
.
generation_kwargs
or
self
.
output_type
==
"greedy_until"
:
if
self
.
generation_kwargs
or
self
.
output_type
==
"greedy_until"
:
assert
self
.
output_type
==
"greedy_until"
,
"passed `generation_kwargs`, but not using a generation request type!"
assert
(
self
.
output_type
==
"greedy_until"
),
"passed `generation_kwargs`, but not using a generation request type!"
# ensure that we greedily generate in absence of explicit arguments otherwise
# ensure that we greedily generate in absence of explicit arguments otherwise
self
.
generation_kwargs
=
{
"do_sample"
:
False
,
"temperature"
:
0.0
}
self
.
generation_kwargs
=
{
"do_sample"
:
False
,
"temperature"
:
0.0
}
...
@@ -532,7 +534,7 @@ class ConfigurableTask(Task):
...
@@ -532,7 +534,7 @@ class ConfigurableTask(Task):
}
}
try
:
try
:
self
.
_metric_fn_list
[
metric_name
]
=
METRIC_REGISTRY
[
metric_name
]
self
.
_metric_fn_list
[
metric_name
]
=
METRIC_REGISTRY
[
metric_name
]
except
:
except
Exception
:
eval_logger
.
warning
(
eval_logger
.
warning
(
f
"Metric
{
metric_name
}
not found, "
f
"Metric
{
metric_name
}
not found, "
"Searching from https://huggingface.co/evaluate-metric"
"Searching from https://huggingface.co/evaluate-metric"
...
@@ -550,15 +552,24 @@ class ConfigurableTask(Task):
...
@@ -550,15 +552,24 @@ class ConfigurableTask(Task):
if
"aggregation"
in
metric_config
:
if
"aggregation"
in
metric_config
:
agg_name
=
metric_config
[
"aggregation"
]
agg_name
=
metric_config
[
"aggregation"
]
self
.
_aggregation_list
[
metric_name
]
=
AGGREGATION_REGISTRY
[
agg_name
]
if
type
(
agg_name
)
==
str
:
self
.
_aggregation_list
[
metric_name
]
=
AGGREGATION_REGISTRY
[
agg_name
]
elif
callable
(
agg_name
):
self
.
_aggregation_list
[
metric_name
]
=
metric_config
[
"aggregation"
]
else
:
else
:
INV_AGG_REGISTRY
=
{
v
:
k
for
k
,
v
in
AGGREGATION_REGISTRY
.
items
()}
metric_agg
=
DEFAULT_AGGREGATION_REGISTRY
[
metric_name
]
eval_logger
.
warning
(
eval_logger
.
warning
(
f
"metric
{
metric_name
}
is defined, but aggregation is not"
f
"metric
{
metric_name
}
is defined, but aggregation is not. "
f
"using default aggregation for
{
metric_name
}
"
f
"using default "
f
"aggregation=
{
INV_AGG_REGISTRY
[
metric_agg
]
}
"
)
)
self
.
_aggregation_list
[
metric_name
]
=
DEFAULT_AGGREGATION_REGISTRY
[
self
.
_aggregation_list
[
metric_name
]
=
metric_agg
metric_name
]
if
"higher_is_better"
in
metric_config
:
if
"higher_is_better"
in
metric_config
:
self
.
_higher_is_better
[
metric_name
]
=
metric_config
[
self
.
_higher_is_better
[
metric_name
]
=
metric_config
[
...
@@ -566,8 +577,9 @@ class ConfigurableTask(Task):
...
@@ -566,8 +577,9 @@ class ConfigurableTask(Task):
]
]
else
:
else
:
eval_logger
.
warning
(
eval_logger
.
warning
(
f
"metric
{
metric_name
}
is defined, but higher_is_better is not"
f
"metric
{
metric_name
}
is defined, but higher_is_better is not. "
f
"using default higher_is_better for
{
metric_name
}
"
f
"using default "
f
"higher_is_better=
{
HIGHER_IS_BETTER_REGISTRY
[
metric_name
]
}
"
)
)
self
.
_higher_is_better
[
metric_name
]
=
HIGHER_IS_BETTER_REGISTRY
[
self
.
_higher_is_better
[
metric_name
]
=
HIGHER_IS_BETTER_REGISTRY
[
metric_name
metric_name
...
@@ -592,9 +604,7 @@ class ConfigurableTask(Task):
...
@@ -592,9 +604,7 @@ class ConfigurableTask(Task):
filter_pipeline
=
build_filter_ensemble
(
filter_name
,
components
)
filter_pipeline
=
build_filter_ensemble
(
filter_name
,
components
)
self
.
_filters
.
append
(
filter_pipeline
)
self
.
_filters
.
append
(
filter_pipeline
)
else
:
else
:
self
.
_filters
=
[
self
.
_filters
=
[
build_filter_ensemble
(
"none"
,
[[
"take_first"
,
None
]])]
build_filter_ensemble
(
"none"
,
[[
"take_first"
,
None
]])
]
if
self
.
_config
.
use_prompt
is
not
None
:
if
self
.
_config
.
use_prompt
is
not
None
:
eval_logger
.
info
(
f
"loading prompt
{
self
.
_config
.
use_prompt
}
"
)
eval_logger
.
info
(
f
"loading prompt
{
self
.
_config
.
use_prompt
}
"
)
...
...
lm_eval/tasks/__init__.py
View file @
7dec84a0
...
@@ -12,6 +12,7 @@ from lm_eval.api.registry import (
...
@@ -12,6 +12,7 @@ from lm_eval.api.registry import (
register_group
,
register_group
,
TASK_REGISTRY
,
TASK_REGISTRY
,
GROUP_REGISTRY
,
GROUP_REGISTRY
,
ALL_TASKS
,
)
)
...
@@ -41,6 +42,9 @@ def include_task_folder(task_dir):
...
@@ -41,6 +42,9 @@ def include_task_folder(task_dir):
)
)
if
"task"
in
config
:
if
"task"
in
config
:
# task_name = "{}:{}".format(
# get_task_name_from_config(config), config["task"]
# )
task_name
=
"{}"
.
format
(
config
[
"task"
])
task_name
=
"{}"
.
format
(
config
[
"task"
])
register_task
(
task_name
)(
SubClass
)
register_task
(
task_name
)(
SubClass
)
...
...
main.py
View file @
7dec84a0
...
@@ -2,10 +2,10 @@ import os
...
@@ -2,10 +2,10 @@ import os
import
json
import
json
import
argparse
import
argparse
from
lm_eval
import
tasks
,
evaluator
,
utils
from
lm_eval
import
evaluator
,
utils
from
lm_eval.api.registry
import
ALL_TASKS
from
lm_eval.logger
import
eval_logger
from
lm_eval.logger
import
eval_logger
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"false"
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"false"
...
@@ -13,7 +13,7 @@ def parse_args():
...
@@ -13,7 +13,7 @@ def parse_args():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--model"
,
required
=
True
)
parser
.
add_argument
(
"--model"
,
required
=
True
)
parser
.
add_argument
(
"--model_args"
,
default
=
""
)
parser
.
add_argument
(
"--model_args"
,
default
=
""
)
parser
.
add_argument
(
"--tasks"
,
default
=
None
,
choices
=
utils
.
MultiChoice
(
tasks
.
ALL_TASKS
))
parser
.
add_argument
(
"--tasks"
,
default
=
None
,
choices
=
utils
.
MultiChoice
(
sorted
(
ALL_TASKS
))
)
parser
.
add_argument
(
"--config"
,
default
=
None
)
parser
.
add_argument
(
"--config"
,
default
=
None
)
parser
.
add_argument
(
"--num_fewshot"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--num_fewshot"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--batch_size"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--batch_size"
,
type
=
str
,
default
=
None
)
...
@@ -44,7 +44,7 @@ def main():
...
@@ -44,7 +44,7 @@ def main():
)
)
if
args
.
tasks
is
None
:
if
args
.
tasks
is
None
:
task_names
=
tasks
.
ALL_TASKS
task_names
=
ALL_TASKS
else
:
else
:
if
os
.
path
.
isdir
(
args
.
tasks
):
if
os
.
path
.
isdir
(
args
.
tasks
):
import
glob
import
glob
...
@@ -56,7 +56,7 @@ def main():
...
@@ -56,7 +56,7 @@ def main():
task_names
.
append
(
config
)
task_names
.
append
(
config
)
else
:
else
:
tasks_list
=
args
.
tasks
.
split
(
","
)
tasks_list
=
args
.
tasks
.
split
(
","
)
task_names
=
utils
.
pattern_match
(
tasks_list
,
tasks
.
ALL_TASKS
)
task_names
=
utils
.
pattern_match
(
tasks_list
,
ALL_TASKS
)
for
task
in
[
task
for
task
in
tasks_list
if
task
not
in
task_names
]:
for
task
in
[
task
for
task
in
tasks_list
if
task
not
in
task_names
]:
if
os
.
path
.
isfile
(
task
):
if
os
.
path
.
isfile
(
task
):
config
=
utils
.
load_yaml_config
(
task
)
config
=
utils
.
load_yaml_config
(
task
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment