Commit d352a549 authored by lintangsutawika's avatar lintangsutawika
Browse files

can load individual custom python class task

parent 671ce18a
......@@ -12,21 +12,23 @@ from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
import logging
# import python tasks
import squadv2
import scrolls
python_tasks = {
"squadv2": squadv2.task.SQuAD2,
"scrolls_quality": scrolls.task.QuALITY,
"scrolls_narrativeqa": scrolls.task.NarrativeQA,
"scrolls_contractnli": scrolls.task.ContractNLI,
"scrolls_govreport": scrolls.task.GovReport,
"scrolls_summscreenfd": scrolls.task.SummScreenFD,
"scrolls_qmsum": scrolls.task.QMSum,
}
# # import python tasks
# import squadv2.task
# import scrolls.task
# python_tasks = {
# "squadv2": squadv2.task.SQuAD2,
# "scrolls_quality": scrolls.task.QuALITY,
# "scrolls_narrativeqa": scrolls.task.NarrativeQA,
# "scrolls_contractnli": scrolls.task.ContractNLI,
# "scrolls_govreport": scrolls.task.GovReport,
# "scrolls_summscreenfd": scrolls.task.SummScreenFD,
# "scrolls_qmsum": scrolls.task.QMSum,
# }
eval_logger = utils.eval_logger
GROUP_KEYS = ["group", "task", "weight_by_size"]
PYTHON_TASK_KEYS = ["task", "class"]
class TaskManager(abc.ABC):
......@@ -43,7 +45,6 @@ class TaskManager(abc.ABC):
self.ALL_TASKS = self.initialize_tasks(
include_path=include_path
)
# + {k:v, "type":"task" for k,v in python_tasks.items()}
def initialize_tasks(self, include_path=None):
......@@ -69,15 +70,25 @@ class TaskManager(abc.ABC):
return False
def _name_is_task(self, name):
if self._name_is_registered(name) and (self.ALL_TASKS[name]["type"] == "task"):
if self._name_is_registered(name) and ("task" in self.ALL_TASKS[name]["type"]):
return True
return False
def _name_is_python_task(self, name):
if self._name_is_registered(name) and (self.ALL_TASKS[name]["type"] == "python_task"):
return True
return False
def _config_is_task(self, config):
if set(config.keys()) <= ["group", "task", "weight_by_size"]:
if set(config.keys()) <= set(GROUP_KEYS):
return False
return True
def _config_is_python_task(self, config):
if set(config.keys()) == set(PYTHON_TASK_KEYS):
return True
return False
def _get_yaml_path(self, name):
assert name in self.ALL_TASKS
return self.ALL_TASKS[name]["yaml_path"]
......@@ -98,18 +109,25 @@ class TaskManager(abc.ABC):
update_config: dict = None
) -> ConfigurableTask:
def load_task(config, task, group=None):
task_object = ConfigurableTask(config=config)
def load_task(config, task, group=None, is_python_class=False):
if is_python_class:
task_object = config["class"]()
else:
task_object = ConfigurableTask(config=config)
if group is not None:
task_object = (group, task_object)
return {task: task_object}
if isinstance(name_or_config, str):
if update_config is not None:
# Process name_or_config as a dict instead
name_or_config = {"task": name_or_config, **update_config}
elif self._name_is_task(name_or_config):
task_config = self._get_config(name_or_config)
return load_task(task_config, task=name_or_config, group=parent_name)
is_python_class=False
if self._name_is_python_task(name_or_config):
is_python_class=True
return load_task(task_config, task=name_or_config, group=parent_name, is_python_class=is_python_class)
else:
group_name = name_or_config
subtask_list = self._get_tasklist(name_or_config)
......@@ -126,9 +144,10 @@ class TaskManager(abc.ABC):
if self._config_is_task(name_or_config):
name = name_or_config["task"]
# If the name is registered as a group
if self._name_is_task(name) is False:
group_name = name
update_config = {k:v for k,v in name_or_config.items() if k is not "task"}
update_config = {k:v for k,v in name_or_config.items() if k != "task"}
subtask_list = self._get_tasklist(name)
if subtask_list == -1:
subtask_list = self._get_config(name)["task"]
......@@ -178,7 +197,17 @@ class TaskManager(abc.ABC):
if f.endswith(".yaml"):
yaml_path = os.path.join(root, f)
config = utils.simple_load_yaml_config(yaml_path)
if list(config.keys()) == ["group", "task"]:
if set(config.keys()) == set(PYTHON_TASK_KEYS):
# This is a python class config
tasks_and_groups[config["task"]] = {
"type": "python_task",
"yaml_path": yaml_path,
}
elif set(config.keys()) <= set(GROUP_KEYS):
print("###")
print(config["group"])
print(config)
print("###")
# This is a group config
tasks_and_groups[config["group"]] = {
"type": "group",
......
......@@ -21,7 +21,6 @@ from packaging import version
from lm_eval.api.task import Task
from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_task
_CITATION = """
@misc{rajpurkar2018know,
......@@ -47,7 +46,6 @@ def _squad_agg(key, items):
return _squad_metric(predictions=predictions, references=references).get(key, 0)
# @register_task("squadv2")
class SQuAD2(Task):
VERSION = 3
DATASET_PATH = "squad_v2"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment