Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
28cc5b6e
Commit
28cc5b6e
authored
Jan 20, 2024
by
lintangsutawika
Browse files
indexing and loading are part of a task_manager object
parent
17172a26
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
121 additions
and
95 deletions
+121
-95
lm_eval/__main__.py
lm_eval/__main__.py
+8
-22
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+113
-73
No files found.
lm_eval/__main__.py
View file @
28cc5b6e
...
...
@@ -9,7 +9,7 @@ from typing import Union
import
numpy
as
np
from
lm_eval
import
evaluator
,
utils
from
lm_eval.tasks
import
initialize_tasks
,
load_task_or_group
from
lm_eval.tasks
import
TaskManager
from
lm_eval.utils
import
make_table
...
...
@@ -155,7 +155,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"false"
# initialize_tasks(args.verbosity)
ALL_TASKS
=
initialize_tasks
(
args
.
verbosity
,
include_path
=
args
.
include_path
)
task_manager
=
TaskManager
(
args
.
verbosity
,
include_path
=
args
.
include_path
)
if
args
.
limit
:
eval_logger
.
warning
(
...
...
@@ -170,7 +170,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
sys
.
exit
()
elif
args
.
tasks
==
"list"
:
eval_logger
.
info
(
"Available Tasks:
\n
- {}"
.
format
(
"
\n
- "
.
join
(
sorted
(
ALL_TASKS
.
key
s
()))
)
"Available Tasks:
\n
- {}"
.
format
(
"
\n
- "
.
join
(
task_manager
.
all_task
s
()))
)
else
:
if
os
.
path
.
isdir
(
args
.
tasks
):
...
...
@@ -183,7 +183,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
loaded_task_list
.
append
(
config
)
else
:
input_task_list
=
args
.
tasks
.
split
(
","
)
loaded_task_list
=
utils
.
pattern_match
(
input_task_list
,
ALL_TASKS
.
key
s
())
loaded_task_list
=
utils
.
pattern_match
(
input_task_list
,
task_manager
.
all_task
s
())
for
task
in
[
task
for
task
in
input_task_list
if
task
not
in
loaded_task_list
]:
...
...
@@ -229,25 +229,11 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
eval_logger
.
info
(
f
"Selected Tasks:
{
loaded_task_list
}
"
)
eval_logger
.
info
(
"Loading selected tasks..."
)
all_tasks
=
{}
for
task
in
loaded_task_list
:
task_object
=
load_task_or_group
(
ALL_TASKS
,
task_name_or_config
=
task
,
)
if
isinstance
(
task
,
str
):
task_name
=
task
elif
isinstance
(
task
,
dict
):
task_name
=
task
[
"task"
]
if
isinstance
(
task_object
,
dict
):
all_tasks
=
{
**
task_object
,
**
all_tasks
}
else
:
all_tasks
[
task_name
]
=
task_object
all_tasks
=
task_manager
.
load_task_or_group
(
loaded_task_list
)
#
for key, value in all_tasks.items():
#
print(key, value)
#
import sys; sys.exit()
for
key
,
value
in
all_tasks
.
items
():
print
(
key
,
value
)
import
sys
;
sys
.
exit
()
results
=
evaluator
.
simple_evaluate
(
model
=
args
.
model
,
...
...
lm_eval/tasks/__init__.py
View file @
28cc5b6e
import
os
import
abc
import
yaml
import
collections
from
typing
import
List
,
Union
,
Dict
...
...
@@ -11,7 +12,7 @@ from lm_eval.api.registry import (
register_group
,
TASK_REGISTRY
,
GROUP_REGISTRY
,
ALL_TASKS
,
self
.
ALL_TASKS
,
)
import
logging
...
...
@@ -35,11 +36,28 @@ def is_group(task):
return
True
return
False
class
TaskManager
(
abc
.
ABC
):
def
load_task_or_group
(
ALL_TASKS
,
task_name_or_config
:
Union
[
str
,
dict
]
=
None
)
->
ConfigurableTask
:
def
__init__
(
self
,
verbosity
=
"INFO"
,
include_path
=
None
)
->
None
:
self
.
ALL_TASKS
=
initialize_tasks
(
verbosity
=
verbosity
,
include_path
=
include_path
)
@
property
def
all_tasks
(
self
):
return
sorted
(
self
.
ALL_TASKS
.
keys
())
def
_load_individual_task_or_group
(
self
,
task_name_or_config
:
Union
[
str
,
dict
]
=
None
)
->
ConfigurableTask
:
print
(
"Loading"
,
task_name_or_config
)
if
isinstance
(
task_name_or_config
,
str
):
task_info
=
ALL_TASKS
[
task_name_or_config
]
task_info
=
self
.
ALL_TASKS
[
task_name_or_config
]
yaml_path
=
task_info
[
"yaml_path"
]
task_type
=
task_info
[
"type"
]
subtask_list
=
task_info
[
"task"
]
if
"task"
in
task_info
else
-
1
...
...
@@ -58,13 +76,13 @@ def load_task_or_group(ALL_TASKS, task_name_or_config: Union[str, dict] = None)
for
task_or_config
in
subtask_list
:
if
isinstance
(
task_or_config
,
str
):
all_subtasks
[
task_or_config
]
=
(
group_name
,
None
)
task_object
=
load
_task_or_group
(
ALL_TASKS
,
task_name_or_config
=
task_or_config
)
task_object
=
self
.
_load_individual
_task_or_group
(
self
.
ALL_TASKS
,
task_name_or_config
=
task_or_config
)
elif
isinstance
(
task_or_config
,
dict
):
if
"group"
in
task_or_config
:
all_subtasks
[
task_or_config
[
"group"
]]
=
(
group_name
,
None
)
elif
"task"
in
task_or_config
:
all_subtasks
[
task_or_config
[
"task"
]]
=
(
group_name
,
None
)
task_object
=
load
_task_or_group
(
ALL_TASKS
,
task_name_or_config
=
task_or_config
)
task_object
=
self
.
_load_individual
_task_or_group
(
self
.
ALL_TASKS
,
task_name_or_config
=
task_or_config
)
if
isinstance
(
task_object
,
dict
):
all_subtasks
=
{
**
task_object
,
**
all_subtasks
}
...
...
@@ -83,10 +101,10 @@ def load_task_or_group(ALL_TASKS, task_name_or_config: Union[str, dict] = None)
all_subtasks
=
{}
for
task_or_config
in
subtask_list
:
if
isinstance
(
task_or_config
,
str
):
task_object
=
load
_task_or_group
(
ALL_TASKS
,
task_name_or_config
=
task_or_config
)
task_object
=
self
.
_load_individual
_task_or_group
(
self
.
ALL_TASKS
,
task_name_or_config
=
task_or_config
)
task_name
=
task_or_config
elif
isinstance
(
task_or_config
,
dict
):
task_object
=
load
_task_or_group
(
ALL_TASKS
,
task_name_or_config
=
task_or_config
)
task_object
=
self
.
_load_individual
_task_or_group
(
self
.
ALL_TASKS
,
task_name_or_config
=
task_or_config
)
if
isinstance
(
task_object
,
dict
):
all_subtasks
=
{
**
task_object
,
**
all_subtasks
}
...
...
@@ -97,7 +115,7 @@ def load_task_or_group(ALL_TASKS, task_name_or_config: Union[str, dict] = None)
else
:
task_type
=
"task"
task_name
=
task_name_or_config
[
"task"
]
base_task_info
=
ALL_TASKS
[
task_name
]
base_task_info
=
self
.
ALL_TASKS
[
task_name
]
base_yaml_path
=
base_task_info
[
"yaml_path"
]
base_task_config
=
utils
.
load_yaml_config
(
base_yaml_path
)
...
...
@@ -108,6 +126,28 @@ def load_task_or_group(ALL_TASKS, task_name_or_config: Union[str, dict] = None)
}
)
def
load_task_or_group
(
self
,
task_list
:
Union
[
str
,
list
]
=
None
)
->
dict
:
if
isinstance
(
task_list
,
str
):
task_list
=
[
task_list
]
all_loaded_tasks
=
{}
for
task
in
task_list
:
task_object
=
self
.
_load_individual_task_or_group
(
task_name_or_config
=
task
,
)
if
isinstance
(
task
,
str
):
task_name
=
task
elif
isinstance
(
task
,
dict
):
task_name
=
task
[
"task"
]
if
isinstance
(
task_object
,
dict
):
all_loaded_tasks
=
{
**
task_object
,
**
self
.
ALL_TASKS
}
else
:
all_loaded_tasks
[
task_name
]
=
task_object
return
all_loaded_tasks
def
register_configurable_task
(
config
:
Dict
[
str
,
str
])
->
int
:
SubClass
=
type
(
...
...
@@ -182,16 +222,16 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -
GROUP_REGISTRY
[
group
].
append
(
sub_group
)
else
:
GROUP_REGISTRY
[
group
]
=
[
sub_group
]
ALL_TASKS
.
add
(
group
)
self
.
ALL_TASKS
.
add
(
group
)
task_names
=
utils
.
pattern_match
(
registered_task_or_group_list
,
ALL_TASKS
)
task_names
=
utils
.
pattern_match
(
registered_task_or_group_list
,
self
.
ALL_TASKS
)
for
task
in
task_names
:
if
(
task
in
TASK_REGISTRY
)
or
(
task
in
GROUP_REGISTRY
):
if
group
in
GROUP_REGISTRY
:
GROUP_REGISTRY
[
group
].
append
(
task
)
else
:
GROUP_REGISTRY
[
group
]
=
[
task
]
ALL_TASKS
.
add
(
group
)
self
.
ALL_TASKS
.
add
(
group
)
return
0
...
...
@@ -345,12 +385,12 @@ def initialize_tasks(verbosity="INFO", include_path=None):
include_path
=
[
include_path
]
all_paths
.
extend
(
include_path
)
ALL_TASKS
=
{}
self
.
ALL_TASKS
=
{}
for
task_dir
in
all_paths
:
tasks
=
get_task_and_group
(
task_dir
)
ALL_TASKS
=
{
**
tasks
,
**
ALL_TASKS
}
self
.
ALL_TASKS
=
{
**
tasks
,
**
self
.
ALL_TASKS
}
return
ALL_TASKS
return
self
.
ALL_TASKS
def
get_task
(
task_name
,
config
):
try
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment