Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
7aee2dff
Commit
7aee2dff
authored
May 15, 2023
by
lintangsutawika
Browse files
moved functions out and some fixes
parent
66c58194
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
45 deletions
+14
-45
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+14
-45
No files found.
lm_eval/tasks/__init__.py
View file @
7aee2dff
import
os
import
os
import
re
import
yaml
from
typing
import
List
,
Union
from
typing
import
List
,
Union
from
.arc
import
*
from
.arc
import
*
from
lm_eval
import
utils
from
lm_eval.api.task
import
TaskConfig
,
Task
,
ConfigurableTask
from
lm_eval.api.task
import
TaskConfig
,
Task
,
ConfigurableTask
from
lm_eval.api.register
import
(
from
lm_eval.api.register
import
(
register_task
,
register_task
,
...
@@ -14,39 +13,8 @@ from lm_eval.api.register import (
...
@@ -14,39 +13,8 @@ from lm_eval.api.register import (
)
)
def
load_yaml_config
(
yaml_path
):
def
get_task_name_from_config
(
task_config
):
with
open
(
yaml_path
,
'rb'
)
as
file
:
return
"configurable_{dataset_path}_{dataset_name}"
.
format
(
**
task_config
)
yaml_config
=
yaml
.
full_load
(
file
)
yaml_dir
=
os
.
path
.
dirname
(
yaml_path
)
if
'include'
in
yaml_config
:
include_path
=
yaml_config
[
'include'
]
del
yaml_config
[
'include'
]
if
type
(
include_path
)
==
str
:
include_path
=
[
include_path
]
# Load from the last one first
include_path
.
reverse
()
final_yaml_config
=
{}
for
path
in
include_path
:
# Assumes that path is a full path.
# If not found, assume the included yaml
# is in the same dir as the original yaml
if
not
os
.
path
.
isfile
(
path
):
path
=
os
.
path
.
join
(
yaml_dir
,
path
)
try
:
included_yaml_config
=
load_yaml_config
(
path
)
final_yaml_config
.
update
(
included_yaml_config
)
except
:
# If failed to load, ignore
pass
final_yaml_config
.
update
(
yaml_config
)
return
final_yaml_config
return
yaml_config
task_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
+
"/"
task_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
+
"/"
...
@@ -56,7 +24,7 @@ for root, subdirs, file_list in os.walk(task_dir):
...
@@ -56,7 +24,7 @@ for root, subdirs, file_list in os.walk(task_dir):
if
"yaml"
in
file
:
if
"yaml"
in
file
:
yaml_path
=
os
.
path
.
join
(
root
,
file
)
yaml_path
=
os
.
path
.
join
(
root
,
file
)
try
:
try
:
config
=
load_yaml_config
(
yaml_path
)
config
=
utils
.
load_yaml_config
(
yaml_path
)
SubClass
=
type
(
SubClass
=
type
(
config
[
'task'
]
+
'ConfigurableTask'
,
config
[
'task'
]
+
'ConfigurableTask'
,
...
@@ -65,13 +33,17 @@ for root, subdirs, file_list in os.walk(task_dir):
...
@@ -65,13 +33,17 @@ for root, subdirs, file_list in os.walk(task_dir):
)
)
if
'task'
in
config
:
if
'task'
in
config
:
register_task
(
config
[
'task'
])(
SubClass
)
task_name
=
"{}:{}"
.
format
(
get_task_name_from_config
(
config
),
config
[
'task'
]
)
register_task
(
task_name
)(
SubClass
)
if
'group'
in
config
:
if
'group'
in
config
:
for
group
in
config
[
'group'
]:
for
group
in
config
[
'group'
]:
register_group
(
group
)(
SubClass
)
register_group
(
group
)(
SubClass
)
except
:
except
:
pass
pass
TASK_REGISTRY
=
task_registry
TASK_REGISTRY
=
task_registry
GROUP_REGISTRY
=
group_registry
GROUP_REGISTRY
=
group_registry
...
@@ -100,10 +72,6 @@ def get_task_name_from_object(task_object):
...
@@ -100,10 +72,6 @@ def get_task_name_from_object(task_object):
)
)
def
get_task_name_from_config
(
task_config
):
return
"configurable_{dataset_path}_{dataset_name}"
.
format
(
**
task_config
)
# TODO: pass num_fewshot and other cmdline overrides in a better way
# TODO: pass num_fewshot and other cmdline overrides in a better way
def
get_task_dict
(
task_name_list
:
List
[
Union
[
str
,
dict
,
Task
]],
**
kwargs
):
def
get_task_dict
(
task_name_list
:
List
[
Union
[
str
,
dict
,
Task
]],
**
kwargs
):
...
@@ -126,6 +94,7 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
...
@@ -126,6 +94,7 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
)
)
}
}
else
:
else
:
task_name
=
task_element
if
task_name
not
in
task_name_from_registry_dict
:
if
task_name
not
in
task_name_from_registry_dict
:
task_name_from_registry_dict
=
{
task_name_from_registry_dict
=
{
**
task_name_from_registry_dict
,
**
task_name_from_registry_dict
,
...
@@ -135,11 +104,11 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
...
@@ -135,11 +104,11 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
}
}
elif
isinstance
(
task_element
,
dict
):
elif
isinstance
(
task_element
,
dict
):
task_element
.
update
(
config
)
task_name_from_config_dict
=
{
task_name_from_config_dict
=
{
**
task_name_from_config_dict
,
**
task_name_from_config_dict
,
get_task_name_from_config
(
task_element
):
ConfigurableTask
(
get_task_name_from_config
(
task_element
):
ConfigurableTask
(
config
=
config
config
=
task_element
)
)
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment