Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
d352a549
Commit
d352a549
authored
Jan 23, 2024
by
lintangsutawika
Browse files
can load individual custom python class task
parent
671ce18a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
49 additions
and
22 deletions
+49
-22
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+49
-20
lm_eval/tasks/squadv2/task.py
lm_eval/tasks/squadv2/task.py
+0
-2
No files found.
lm_eval/tasks/__init__.py
View file @
d352a549
...
...
@@ -12,21 +12,23 @@ from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
import
logging
# import python tasks
import
squadv2
import
scrolls
python_tasks
=
{
"squadv2"
:
squadv2
.
task
.
SQuAD2
,
"scrolls_quality"
:
scrolls
.
task
.
QuALITY
,
"scrolls_narrativeqa"
:
scrolls
.
task
.
NarrativeQA
,
"scrolls_contractnli"
:
scrolls
.
task
.
ContractNLI
,
"scrolls_govreport"
:
scrolls
.
task
.
GovReport
,
"scrolls_summscreenfd"
:
scrolls
.
task
.
SummScreenFD
,
"scrolls_qmsum"
:
scrolls
.
task
.
QMSum
,
}
#
# import python tasks
#
import squadv2
.task
#
import scrolls
.task
#
python_tasks = {
#
"squadv2": squadv2.task.SQuAD2,
#
"scrolls_quality": scrolls.task.QuALITY,
#
"scrolls_narrativeqa": scrolls.task.NarrativeQA,
#
"scrolls_contractnli": scrolls.task.ContractNLI,
#
"scrolls_govreport": scrolls.task.GovReport,
#
"scrolls_summscreenfd": scrolls.task.SummScreenFD,
#
"scrolls_qmsum": scrolls.task.QMSum,
#
}
eval_logger
=
utils
.
eval_logger
GROUP_KEYS
=
[
"group"
,
"task"
,
"weight_by_size"
]
PYTHON_TASK_KEYS
=
[
"task"
,
"class"
]
class
TaskManager
(
abc
.
ABC
):
...
...
@@ -43,7 +45,6 @@ class TaskManager(abc.ABC):
self
.
ALL_TASKS
=
self
.
initialize_tasks
(
include_path
=
include_path
)
# + {k:v, "type":"task" for k,v in python_tasks.items()}
def
initialize_tasks
(
self
,
include_path
=
None
):
...
...
@@ -69,15 +70,25 @@ class TaskManager(abc.ABC):
return
False
def
_name_is_task
(
self
,
name
):
if
self
.
_name_is_registered
(
name
)
and
(
self
.
ALL_TASKS
[
name
][
"type"
]
==
"task"
):
if
self
.
_name_is_registered
(
name
)
and
(
"task"
in
self
.
ALL_TASKS
[
name
][
"type"
]):
return
True
return
False
def
_name_is_python_task
(
self
,
name
):
if
self
.
_name_is_registered
(
name
)
and
(
self
.
ALL_TASKS
[
name
][
"type"
]
==
"python_task"
):
return
True
return
False
def
_config_is_task
(
self
,
config
):
if
set
(
config
.
keys
())
<=
[
"group"
,
"task"
,
"weight_by_size"
]
:
if
set
(
config
.
keys
())
<=
set
(
GROUP_KEYS
)
:
return
False
return
True
def
_config_is_python_task
(
self
,
config
):
if
set
(
config
.
keys
())
==
set
(
PYTHON_TASK_KEYS
):
return
True
return
False
def
_get_yaml_path
(
self
,
name
):
assert
name
in
self
.
ALL_TASKS
return
self
.
ALL_TASKS
[
name
][
"yaml_path"
]
...
...
@@ -98,18 +109,25 @@ class TaskManager(abc.ABC):
update_config
:
dict
=
None
)
->
ConfigurableTask
:
def
load_task
(
config
,
task
,
group
=
None
):
task_object
=
ConfigurableTask
(
config
=
config
)
def
load_task
(
config
,
task
,
group
=
None
,
is_python_class
=
False
):
if
is_python_class
:
task_object
=
config
[
"class"
]()
else
:
task_object
=
ConfigurableTask
(
config
=
config
)
if
group
is
not
None
:
task_object
=
(
group
,
task_object
)
return
{
task
:
task_object
}
if
isinstance
(
name_or_config
,
str
):
if
update_config
is
not
None
:
# Process name_or_config as a dict instead
name_or_config
=
{
"task"
:
name_or_config
,
**
update_config
}
elif
self
.
_name_is_task
(
name_or_config
):
task_config
=
self
.
_get_config
(
name_or_config
)
return
load_task
(
task_config
,
task
=
name_or_config
,
group
=
parent_name
)
is_python_class
=
False
if
self
.
_name_is_python_task
(
name_or_config
):
is_python_class
=
True
return
load_task
(
task_config
,
task
=
name_or_config
,
group
=
parent_name
,
is_python_class
=
is_python_class
)
else
:
group_name
=
name_or_config
subtask_list
=
self
.
_get_tasklist
(
name_or_config
)
...
...
@@ -126,9 +144,10 @@ class TaskManager(abc.ABC):
if
self
.
_config_is_task
(
name_or_config
):
name
=
name_or_config
[
"task"
]
# If the name is registered as a group
if
self
.
_name_is_task
(
name
)
is
False
:
group_name
=
name
update_config
=
{
k
:
v
for
k
,
v
in
name_or_config
.
items
()
if
k
is
not
"task"
}
update_config
=
{
k
:
v
for
k
,
v
in
name_or_config
.
items
()
if
k
!=
"task"
}
subtask_list
=
self
.
_get_tasklist
(
name
)
if
subtask_list
==
-
1
:
subtask_list
=
self
.
_get_config
(
name
)[
"task"
]
...
...
@@ -178,7 +197,17 @@ class TaskManager(abc.ABC):
if
f
.
endswith
(
".yaml"
):
yaml_path
=
os
.
path
.
join
(
root
,
f
)
config
=
utils
.
simple_load_yaml_config
(
yaml_path
)
if
list
(
config
.
keys
())
==
[
"group"
,
"task"
]:
if
set
(
config
.
keys
())
==
set
(
PYTHON_TASK_KEYS
):
# This is a python class config
tasks_and_groups
[
config
[
"task"
]]
=
{
"type"
:
"python_task"
,
"yaml_path"
:
yaml_path
,
}
elif
set
(
config
.
keys
())
<=
set
(
GROUP_KEYS
):
print
(
"###"
)
print
(
config
[
"group"
])
print
(
config
)
print
(
"###"
)
# This is a group config
tasks_and_groups
[
config
[
"group"
]]
=
{
"type"
:
"group"
,
...
...
lm_eval/tasks/squadv2/task.py
View file @
d352a549
...
...
@@ -21,7 +21,6 @@ from packaging import version
from
lm_eval.api.task
import
Task
from
lm_eval.api.instance
import
Instance
from
lm_eval.api.registry
import
register_task
_CITATION
=
"""
@misc{rajpurkar2018know,
...
...
@@ -47,7 +46,6 @@ def _squad_agg(key, items):
return
_squad_metric
(
predictions
=
predictions
,
references
=
references
).
get
(
key
,
0
)
# @register_task("squadv2")
class
SQuAD2
(
Task
):
VERSION
=
3
DATASET_PATH
=
"squad_v2"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment