Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
e5811879
Commit
e5811879
authored
Jul 02, 2024
by
haileyschoelkopf
Browse files
Python tasks which subclass ConfigurableTask now run
parent
f2e518ab
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
17 additions
and
13 deletions
+17
-13
lm_eval/evaluator_utils.py
lm_eval/evaluator_utils.py
+4
-4
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+10
-6
lm_eval/tasks/fda/task.py
lm_eval/tasks/fda/task.py
+1
-1
lm_eval/tasks/squad_completion/task.py
lm_eval/tasks/squad_completion/task.py
+1
-1
lm_eval/tasks/swde/task.py
lm_eval/tasks/swde/task.py
+1
-1
No files found.
lm_eval/evaluator_utils.py
View file @
e5811879
...
@@ -9,7 +9,7 @@ from lm_eval.api.metrics import (
...
@@ -9,7 +9,7 @@ from lm_eval.api.metrics import (
pooled_sample_stderr
,
pooled_sample_stderr
,
stderr_for_metric
,
stderr_for_metric
,
)
)
from
lm_eval.api.task
import
ConfigurableGroup
,
Configurable
Task
from
lm_eval.api.task
import
ConfigurableGroup
,
Task
from
lm_eval.utils
import
eval_logger
,
positional_deprecated
from
lm_eval.utils
import
eval_logger
,
positional_deprecated
...
@@ -167,7 +167,7 @@ def get_subtask_list(task_dict, task_root=None, depth=0):
...
@@ -167,7 +167,7 @@ def get_subtask_list(task_dict, task_root=None, depth=0):
if
isinstance
(
task_obj
,
ConfigurableGroup
):
if
isinstance
(
task_obj
,
ConfigurableGroup
):
# group_or_task_name = task_obj.group_name
# group_or_task_name = task_obj.group_name
group_or_task_name
=
task_obj
.
group_name
group_or_task_name
=
task_obj
.
group_name
elif
isinstance
(
task_obj
,
Configurable
Task
):
elif
isinstance
(
task_obj
,
Task
):
# group_or_task_name = task_obj.task_name
# group_or_task_name = task_obj.task_name
group_or_task_name
=
task_obj
.
task_name
group_or_task_name
=
task_obj
.
task_name
...
@@ -237,7 +237,7 @@ def prepare_print_tasks(
...
@@ -237,7 +237,7 @@ def prepare_print_tasks(
from_configurable_group
=
True
from_configurable_group
=
True
elif
isinstance
(
task_or_group_name
,
str
):
elif
isinstance
(
task_or_group_name
,
str
):
name
=
task_or_group_name
name
=
task_or_group_name
if
isinstance
(
task_or_group_obj
,
Configurable
Task
):
if
isinstance
(
task_or_group_obj
,
Task
):
# string_name = task_or_group_obj.task_name
# string_name = task_or_group_obj.task_name
name
=
task_or_group_obj
.
task_name
name
=
task_or_group_obj
.
task_name
from_configurable_group
=
False
from_configurable_group
=
False
...
@@ -378,7 +378,7 @@ def consolidate_group_results(
...
@@ -378,7 +378,7 @@ def consolidate_group_results(
else
:
else
:
group_config
=
None
group_config
=
None
if
isinstance
(
group_or_task_info
,
Configurable
Task
):
if
isinstance
(
group_or_task_info
,
Task
):
if
task_root
:
if
task_root
:
task_aggregation_list
.
setdefault
(
task_root
,
[]).
append
(
task_aggregation_list
.
setdefault
(
task_root
,
[]).
append
(
group_or_task_info
.
task_name
group_or_task_info
.
task_name
...
...
lm_eval/tasks/__init__.py
View file @
e5811879
...
@@ -151,14 +151,16 @@ class TaskManager:
...
@@ -151,14 +151,16 @@ class TaskManager:
**
config
,
**
config
,
}
}
if
self
.
_config_is_python_task
(
config
):
if
self
.
_config_is_python_task
(
config
):
task_object
=
config
[
"class"
](
config
=
config
)
task_object
=
(
config
[
"class"
](
config
=
config
)
if
isinstance
(
config
[
"class"
],
ConfigurableTask
)
else
config
[
"class"
]()
)
# very scuffed: set task name here TODO: fixme?
task_object
.
config
.
task
=
config
[
"task"
]
else
:
else
:
task_object
=
ConfigurableTask
(
config
=
config
)
task_object
=
ConfigurableTask
(
config
=
config
)
# if task != task_object.task_id:
# assert False
# task_object.task_id = task
return
{
task
:
task_object
}
return
{
task
:
task_object
}
def
_get_group_and_subtask_from_config
(
config
):
def
_get_group_and_subtask_from_config
(
config
):
...
@@ -187,7 +189,9 @@ class TaskManager:
...
@@ -187,7 +189,9 @@ class TaskManager:
if
update_config
is
not
None
:
if
update_config
is
not
None
:
# Process name_or_config as a dict instead
# Process name_or_config as a dict instead
name_or_config
=
{
"task"
:
name_or_config
,
**
update_config
}
name_or_config
=
{
"task"
:
name_or_config
,
**
update_config
}
elif
self
.
_name_is_task
(
name_or_config
):
elif
self
.
_name_is_task
(
name_or_config
)
or
self
.
_name_is_python_task
(
name_or_config
):
task_config
=
self
.
_get_config
(
name_or_config
)
task_config
=
self
.
_get_config
(
name_or_config
)
return
_load_task
(
task_config
,
task
=
name_or_config
)
return
_load_task
(
task_config
,
task
=
name_or_config
)
else
:
else
:
...
...
lm_eval/tasks/fda/task.py
View file @
e5811879
...
@@ -14,7 +14,7 @@ class FDA(ConfigurableTask):
...
@@ -14,7 +14,7 @@ class FDA(ConfigurableTask):
DATASET_PATH
=
"hazyresearch/based-fda"
DATASET_PATH
=
"hazyresearch/based-fda"
DATASET_NAME
=
"default"
DATASET_NAME
=
"default"
def
__init__
(
self
):
def
__init__
(
self
,
**
kwargs
):
super
().
__init__
(
config
=
{
"metadata"
:
{
"version"
:
self
.
VERSION
}})
super
().
__init__
(
config
=
{
"metadata"
:
{
"version"
:
self
.
VERSION
}})
def
has_training_docs
(
self
):
def
has_training_docs
(
self
):
...
...
lm_eval/tasks/squad_completion/task.py
View file @
e5811879
...
@@ -14,7 +14,7 @@ class SQUADCompletion(ConfigurableTask):
...
@@ -14,7 +14,7 @@ class SQUADCompletion(ConfigurableTask):
DATASET_PATH
=
"hazyresearch/based-squad"
DATASET_PATH
=
"hazyresearch/based-squad"
DATASET_NAME
=
"default"
DATASET_NAME
=
"default"
def
__init__
(
self
):
def
__init__
(
self
,
**
kwargs
):
super
().
__init__
(
config
=
{
"metadata"
:
{
"version"
:
self
.
VERSION
}})
super
().
__init__
(
config
=
{
"metadata"
:
{
"version"
:
self
.
VERSION
}})
def
has_training_docs
(
self
):
def
has_training_docs
(
self
):
...
...
lm_eval/tasks/swde/task.py
View file @
e5811879
...
@@ -12,7 +12,7 @@ class SWDE(ConfigurableTask):
...
@@ -12,7 +12,7 @@ class SWDE(ConfigurableTask):
DATASET_PATH
=
"hazyresearch/based-swde-v2"
DATASET_PATH
=
"hazyresearch/based-swde-v2"
DATASET_NAME
=
"default"
DATASET_NAME
=
"default"
def
__init__
(
self
):
def
__init__
(
self
,
**
kwargs
):
super
().
__init__
(
config
=
{
"metadata"
:
{
"version"
:
self
.
VERSION
}})
super
().
__init__
(
config
=
{
"metadata"
:
{
"version"
:
self
.
VERSION
}})
def
has_training_docs
(
self
):
def
has_training_docs
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment