Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
ca0b8d45
Commit
ca0b8d45
authored
May 13, 2023
by
lintangsutawika
Browse files
modified how yaml and python functions are added to groups and task registry
parent
275857a1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
156 additions
and
57 deletions
+156
-57
lm_eval/api/register.py
lm_eval/api/register.py
+50
-0
lm_eval/api/task.py
lm_eval/api/task.py
+5
-5
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+101
-52
No files found.
lm_eval/api/register.py
0 → 100644
View file @
ca0b8d45
import
os
# Task name -> registered callable (task class or factory function).
task_registry: dict = {}

# Group name -> list of task names that belong to the group.
group_registry: dict = {}

# Task name -> ``__name__`` of the callable registered under that task name.
task2func_index: dict = {}

# ``__name__`` of a registered callable -> task name it was registered under.
func2task_index: dict = {}
def register_task(name):
    """Return a decorator that registers a callable as the task ``name``.

    The decorator stores the callable in ``task_registry`` under ``name`` and
    records the two-way mapping between ``name`` and the callable's
    ``__name__`` in ``task2func_index`` / ``func2task_index``. The decorated
    callable is returned unchanged.
    """

    def wrapper(func):
        callable_name = func.__name__

        task_registry[name] = func
        func2task_index[callable_name] = name
        task2func_index[name] = callable_name

        return func

    return wrapper
def register_group(name):
    """Return a decorator that adds an already-registered task to group ``name``.

    The decorated callable must have been registered with ``register_task``
    first (i.e. ``@register_group`` is applied *above* ``@register_task``),
    because the group stores the task name looked up through
    ``func2task_index`` rather than the callable itself.

    Raises:
        KeyError: if the callable has not been registered as a task yet.
    """

    def wrapper(func):
        try:
            task_name = func2task_index[func.__name__]
        except KeyError:
            # Same exception type as before, but with a message that explains
            # what went wrong instead of just echoing the function name.
            raise KeyError(
                f"Cannot add '{func.__name__}' to group '{name}': "
                "it has not been registered with register_task"
            ) from None

        # Create the group on first use, then append the task name.
        group_registry.setdefault(name, []).append(task_name)
        return func

    return wrapper
# @register_group('group_a')
# @register_task('a')
# def foo():
# pass
# @register_group('group_a')
# @register_task('b')
# def fii():
# pass
# @register_group('group_b')
# @register_task('c')
# def bar():
# pass
# name = 'a'  # or args.type
# func_to_call = task_registry[name]
# func_to_call()  # actual call is done here
\ No newline at end of file
lm_eval/api/task.py
View file @
ca0b8d45
...
...
@@ -26,10 +26,10 @@ from lm_eval.filters import build_filter_ensemble
@
dataclass
class
TaskConfig
(
yaml
.
YAMLObject
):
yaml_tag
=
u
'!task'
class
TaskConfig
(
dict
):
task
:
str
=
None
group
:
str
=
None
names
:
str
=
None
reference
:
str
=
None
task_name
:
str
=
None
# TODO: deprecate this, it'll be set in __post_init__ to be names[0]
...
...
@@ -89,7 +89,6 @@ class Task(abc.ABC):
VERSION
=
None
TASK_NAME
:
str
=
None
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
# or a path to a custom `datasets` loading script.
DATASET_PATH
:
str
=
None
...
...
@@ -430,7 +429,8 @@ class ConfigurableTask(Task):
self
.
_config
=
TaskConfig
(
**
config
)
# Overwrite configs
else
:
self
.
_config
.
__dict__
.
update
(
config
)
if
config
!=
None
:
self
.
_config
.
__dict__
.
update
(
config
)
if
self
.
_config
is
None
:
raise
ValueError
(
"Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
)
...
...
lm_eval/tasks/__init__.py
View file @
ca0b8d45
import
os
import
re
import
yaml
from
typing
import
List
,
Union
from
.vanilla
import
*
from
lm_eval.utils
import
get_yaml_config
,
register_task
from
lm_eval.api.task
import
Task
,
ConfigurableTask
YAML_REGISTRY
=
{}
FUNC_REGISTRY
=
register_task
.
all
BENCHMARK_REGISTRY
=
{}
from
.arc
import
*
# we want to register all yaml tasks in our .yaml folder.
yaml_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
+
"/"
+
"yaml"
for
yaml_file
in
sorted
(
os
.
listdir
(
yaml_dir
)):
yaml_path
=
os
.
path
.
join
(
yaml_dir
,
yaml_file
)
names
=
re
.
sub
(
r
"\."
,
"_"
,
yaml_path
.
split
(
"/"
)[
-
1
])
YAML_REGISTRY
[
names
]
=
yaml_path
from
lm_eval.api.task
import
TaskConfig
,
Task
,
ConfigurableTask
from
lm_eval.api.register
import
(
register_task
,
register_group
,
task_registry
,
group_registry
)
TASK_REGISTRY
=
list
(
YAML_REGISTRY
.
keys
())
+
list
(
FUNC_REGISTRY
.
keys
())
# Discover the YAML task configs that live next to this module and register
# each one as a dynamically created ``ConfigurableTask`` subclass.
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"

for root, subdirs, file_list in os.walk(task_dir):
    # Only leaf directories (no sub-directories) that contain files are scanned.
    if subdirs or not file_list:
        continue

    for file in file_list:
        if "yaml" not in file:
            continue

        yaml_path = os.path.join(root, file)
        try:
            # Close the handle deterministically instead of leaking it.
            with open(yaml_path, "rb") as config_file:
                # NOTE(review): full_load resolves the full YAML language;
                # consider yaml.safe_load if configs can come from untrusted
                # sources.
                config = yaml.full_load(config_file)

            # A config without a task name cannot be registered; skip it
            # before anything tries to dereference config["task"].
            if not isinstance(config, dict) or "task" not in config:
                continue

            SubClass = type(
                config["task"] + "ConfigurableTask",
                (ConfigurableTask,),
                {"CONFIG": TaskConfig(**config)},
            )
            register_task(config["task"])(SubClass)

            groups = config.get("group") or []
            if isinstance(groups, str):
                # A single group given as a plain string must not be iterated
                # character by character.
                groups = [groups]
            for group in groups:
                register_group(group)(SubClass)
        except Exception:
            # Discovery is deliberately best-effort: a file that cannot be
            # read, parsed or turned into a task is skipped. Unlike a bare
            # ``except:``, this lets KeyboardInterrupt/SystemExit propagate.
            continue
# Public aliases for the registries populated by the decorators.
TASK_REGISTRY = task_registry
GROUP_REGISTRY = group_registry

# Alphabetically sorted names of every registered task.
ALL_TASKS = sorted(TASK_REGISTRY)
def
get_task
(
task_name
,
task_config
):
if
task_name
in
TASK_REGISTRY
:
if
task_name
in
YAML_REGISTRY
:
return
ConfigurableTask
(
config
=
{
**
get_yaml_config
(
YAML_REGISTRY
[
task_name
]),
**
task_config
}
)
elif
task_name
in
FUNC_REGISTRY
:
return
FUNC_REGISTRY
[
task_name
](
config
=
task_config
)
else
:
def get_task(task_name, config):
    """Instantiate the task registered under ``task_name``.

    Args:
        task_name: key to look up in ``TASK_REGISTRY``.
        config: passed as the single positional argument to the registered
            callable.

    Returns:
        Whatever the registered callable returns for ``config``.

    Raises:
        KeyError: if ``task_name`` is not registered. The available tasks are
            printed first to help the caller spot typos.
    """
    try:
        task_factory = TASK_REGISTRY[task_name]
    except KeyError:
        from pprint import pprint

        print("Available tasks:")
        pprint(TASK_REGISTRY)
        raise KeyError(f"Missing task {task_name}") from None

    # Construct outside the ``try`` so that a KeyError raised inside the task's
    # own constructor is not misreported as a missing task.
    return task_factory(config)
...
...
@@ -63,33 +69,76 @@ def get_task_name_from_config(task_config):
# TODO: pass num_fewshot and other cmdline overrides in a better way
def
get_task_dict
(
task_name_list
:
List
[
Union
[
str
,
dict
,
Task
]],
num_fewshot
=
None
):
task_name_from_registry_dict
=
{
task_name
:
get_task
(
task_name
=
task_name
,
task_config
=
{
"num_fewshot"
:
num_fewshot
if
num_fewshot
else
0
}
)
for
task_name
in
task_name_list
if
isinstance
(
task_name
,
str
)
}
task_name_from_config_dict
=
{
get_task_name_from_config
(
task_config
):
ConfigurableTask
(
config
=
task_config
)
for
task_config
in
task_name_list
if
isinstance
(
task_config
,
dict
)
}
# TODO: Do we still need this?
def get_task_dict(task_name_list: List[Union[str, dict, Task]], config, **kwargs):
    """Build a ``{task name: task object}`` mapping from a mixed list.

    Each element of ``task_name_list`` is handled according to its type:

    * ``str`` - a group name in ``GROUP_REGISTRY`` expands to all of its
      member tasks; any other string is treated as a single task name. Tasks
      are created through ``get_task`` and each name is instantiated once.
    * ``dict`` - an inline task config, turned into a ``ConfigurableTask``.
      Entries of ``config`` override the inline values.
      NOTE(review): assumes ``config`` is a mapping or ``None`` - confirm
      against callers.
    * ``Task`` - an already constructed task, used as is.

    Args:
        task_name_list: task names, group names, inline configs and/or task
            objects.
        config: configuration forwarded to every task created here.
        **kwargs: accepted for call-site compatibility; currently unused.

    Returns:
        A dict merging registry-, config- and object-derived tasks, in that
        order.

    Raises:
        AssertionError: if a task object shares its name with a task created
            from the registry.
    """
    task_name_from_registry_dict = {}
    task_name_from_config_dict = {}
    task_name_from_object_dict = {}

    for task_element in task_name_list:
        if isinstance(task_element, str):
            if task_element in GROUP_REGISTRY:
                requested_names = GROUP_REGISTRY[task_element]
            else:
                # A plain task name: use the element itself, not a name left
                # over from a previously expanded group.
                requested_names = [task_element]

            for task_name in requested_names:
                if task_name not in task_name_from_registry_dict:
                    task_name_from_registry_dict[task_name] = get_task(
                        task_name=task_name, config=config
                    )

        elif isinstance(task_element, dict):
            # Keep the element's own settings and layer the shared config on
            # top, without mutating the caller's dict.
            merged_config = {**task_element, **(config or {})}
            task_name_from_config_dict[
                get_task_name_from_config(merged_config)
            ] = ConfigurableTask(config=merged_config)

        elif isinstance(task_element, Task):
            task_name_from_object_dict[
                get_task_name_from_object(task_element)
            ] = task_element

    assert set(task_name_from_registry_dict.keys()).isdisjoint(
        set(task_name_from_object_dict.keys())
    )
    return {
        **task_name_from_registry_dict,
        **task_name_from_config_dict,
        **task_name_from_object_dict,
    }
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment