Commit e5811879 authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

Python tasks which subclass ConfigurableTask now run

parent f2e518ab
......@@ -9,7 +9,7 @@ from lm_eval.api.metrics import (
pooled_sample_stderr,
stderr_for_metric,
)
from lm_eval.api.task import ConfigurableGroup, ConfigurableTask
from lm_eval.api.task import ConfigurableGroup, Task
from lm_eval.utils import eval_logger, positional_deprecated
......@@ -167,7 +167,7 @@ def get_subtask_list(task_dict, task_root=None, depth=0):
if isinstance(task_obj, ConfigurableGroup):
# group_or_task_name = task_obj.group_name
group_or_task_name = task_obj.group_name
elif isinstance(task_obj, ConfigurableTask):
elif isinstance(task_obj, Task):
# group_or_task_name = task_obj.task_name
group_or_task_name = task_obj.task_name
......@@ -237,7 +237,7 @@ def prepare_print_tasks(
from_configurable_group = True
elif isinstance(task_or_group_name, str):
name = task_or_group_name
if isinstance(task_or_group_obj, ConfigurableTask):
if isinstance(task_or_group_obj, Task):
# string_name = task_or_group_obj.task_name
name = task_or_group_obj.task_name
from_configurable_group = False
......@@ -378,7 +378,7 @@ def consolidate_group_results(
else:
group_config = None
if isinstance(group_or_task_info, ConfigurableTask):
if isinstance(group_or_task_info, Task):
if task_root:
task_aggregation_list.setdefault(task_root, []).append(
group_or_task_info.task_name
......
......@@ -151,14 +151,16 @@ class TaskManager:
**config,
}
if self._config_is_python_task(config):
task_object = config["class"](config=config)
task_object = (
config["class"](config=config)
if isinstance(config["class"], ConfigurableTask)
else config["class"]()
)
# very scuffed: set task name here TODO: fixme?
task_object.config.task = config["task"]
else:
task_object = ConfigurableTask(config=config)
# if task != task_object.task_id:
# assert False
# task_object.task_id = task
return {task: task_object}
def _get_group_and_subtask_from_config(config):
......@@ -187,7 +189,9 @@ class TaskManager:
if update_config is not None:
# Process name_or_config as a dict instead
name_or_config = {"task": name_or_config, **update_config}
elif self._name_is_task(name_or_config):
elif self._name_is_task(name_or_config) or self._name_is_python_task(
name_or_config
):
task_config = self._get_config(name_or_config)
return _load_task(task_config, task=name_or_config)
else:
......
......@@ -14,7 +14,7 @@ class FDA(ConfigurableTask):
DATASET_PATH = "hazyresearch/based-fda"
DATASET_NAME = "default"
def __init__(self):
def __init__(self, **kwargs):
super().__init__(config={"metadata": {"version": self.VERSION}})
def has_training_docs(self):
......
......@@ -14,7 +14,7 @@ class SQUADCompletion(ConfigurableTask):
DATASET_PATH = "hazyresearch/based-squad"
DATASET_NAME = "default"
def __init__(self):
def __init__(self, **kwargs):
super().__init__(config={"metadata": {"version": self.VERSION}})
def has_training_docs(self):
......
......@@ -12,7 +12,7 @@ class SWDE(ConfigurableTask):
DATASET_PATH = "hazyresearch/based-swde-v2"
DATASET_NAME = "default"
def __init__(self):
def __init__(self, **kwargs):
super().__init__(config={"metadata": {"version": self.VERSION}})
def has_training_docs(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment