Commit d352a549 authored by lintangsutawika's avatar lintangsutawika
Browse files

can load individual custom python class task

parent 671ce18a
...@@ -12,21 +12,23 @@ from lm_eval.api.task import TaskConfig, Task, ConfigurableTask ...@@ -12,21 +12,23 @@ from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
import logging import logging
# import python tasks # # import python tasks
import squadv2 # import squadv2.task
import scrolls # import scrolls.task
python_tasks = { # python_tasks = {
"squadv2": squadv2.task.SQuAD2, # "squadv2": squadv2.task.SQuAD2,
"scrolls_quality": scrolls.task.QuALITY, # "scrolls_quality": scrolls.task.QuALITY,
"scrolls_narrativeqa": scrolls.task.NarrativeQA, # "scrolls_narrativeqa": scrolls.task.NarrativeQA,
"scrolls_contractnli": scrolls.task.ContractNLI, # "scrolls_contractnli": scrolls.task.ContractNLI,
"scrolls_govreport": scrolls.task.GovReport, # "scrolls_govreport": scrolls.task.GovReport,
"scrolls_summscreenfd": scrolls.task.SummScreenFD, # "scrolls_summscreenfd": scrolls.task.SummScreenFD,
"scrolls_qmsum": scrolls.task.QMSum, # "scrolls_qmsum": scrolls.task.QMSum,
} # }
eval_logger = utils.eval_logger eval_logger = utils.eval_logger
GROUP_KEYS = ["group", "task", "weight_by_size"]
PYTHON_TASK_KEYS = ["task", "class"]
class TaskManager(abc.ABC): class TaskManager(abc.ABC):
...@@ -43,7 +45,6 @@ class TaskManager(abc.ABC): ...@@ -43,7 +45,6 @@ class TaskManager(abc.ABC):
self.ALL_TASKS = self.initialize_tasks( self.ALL_TASKS = self.initialize_tasks(
include_path=include_path include_path=include_path
) )
# + {k:v, "type":"task" for k,v in python_tasks.items()}
def initialize_tasks(self, include_path=None): def initialize_tasks(self, include_path=None):
...@@ -69,15 +70,25 @@ class TaskManager(abc.ABC): ...@@ -69,15 +70,25 @@ class TaskManager(abc.ABC):
return False return False
def _name_is_task(self, name): def _name_is_task(self, name):
if self._name_is_registered(name) and (self.ALL_TASKS[name]["type"] == "task"): if self._name_is_registered(name) and ("task" in self.ALL_TASKS[name]["type"]):
return True
return False
def _name_is_python_task(self, name):
if self._name_is_registered(name) and (self.ALL_TASKS[name]["type"] == "python_task"):
return True return True
return False return False
def _config_is_task(self, config): def _config_is_task(self, config):
if set(config.keys()) <= ["group", "task", "weight_by_size"]: if set(config.keys()) <= set(GROUP_KEYS):
return False return False
return True return True
def _config_is_python_task(self, config):
if set(config.keys()) == set(PYTHON_TASK_KEYS):
return True
return False
def _get_yaml_path(self, name): def _get_yaml_path(self, name):
assert name in self.ALL_TASKS assert name in self.ALL_TASKS
return self.ALL_TASKS[name]["yaml_path"] return self.ALL_TASKS[name]["yaml_path"]
...@@ -98,18 +109,25 @@ class TaskManager(abc.ABC): ...@@ -98,18 +109,25 @@ class TaskManager(abc.ABC):
update_config: dict = None update_config: dict = None
) -> ConfigurableTask: ) -> ConfigurableTask:
def load_task(config, task, group=None): def load_task(config, task, group=None, is_python_class=False):
task_object = ConfigurableTask(config=config) if is_python_class:
task_object = config["class"]()
else:
task_object = ConfigurableTask(config=config)
if group is not None: if group is not None:
task_object = (group, task_object) task_object = (group, task_object)
return {task: task_object} return {task: task_object}
if isinstance(name_or_config, str): if isinstance(name_or_config, str):
if update_config is not None: if update_config is not None:
# Process name_or_config as a dict instead
name_or_config = {"task": name_or_config, **update_config} name_or_config = {"task": name_or_config, **update_config}
elif self._name_is_task(name_or_config): elif self._name_is_task(name_or_config):
task_config = self._get_config(name_or_config) task_config = self._get_config(name_or_config)
return load_task(task_config, task=name_or_config, group=parent_name) is_python_class=False
if self._name_is_python_task(name_or_config):
is_python_class=True
return load_task(task_config, task=name_or_config, group=parent_name, is_python_class=is_python_class)
else: else:
group_name = name_or_config group_name = name_or_config
subtask_list = self._get_tasklist(name_or_config) subtask_list = self._get_tasklist(name_or_config)
...@@ -126,9 +144,10 @@ class TaskManager(abc.ABC): ...@@ -126,9 +144,10 @@ class TaskManager(abc.ABC):
if self._config_is_task(name_or_config): if self._config_is_task(name_or_config):
name = name_or_config["task"] name = name_or_config["task"]
# If the name is registered as a group
if self._name_is_task(name) is False: if self._name_is_task(name) is False:
group_name = name group_name = name
update_config = {k:v for k,v in name_or_config.items() if k is not "task"} update_config = {k:v for k,v in name_or_config.items() if k != "task"}
subtask_list = self._get_tasklist(name) subtask_list = self._get_tasklist(name)
if subtask_list == -1: if subtask_list == -1:
subtask_list = self._get_config(name)["task"] subtask_list = self._get_config(name)["task"]
...@@ -178,7 +197,17 @@ class TaskManager(abc.ABC): ...@@ -178,7 +197,17 @@ class TaskManager(abc.ABC):
if f.endswith(".yaml"): if f.endswith(".yaml"):
yaml_path = os.path.join(root, f) yaml_path = os.path.join(root, f)
config = utils.simple_load_yaml_config(yaml_path) config = utils.simple_load_yaml_config(yaml_path)
if list(config.keys()) == ["group", "task"]: if set(config.keys()) == set(PYTHON_TASK_KEYS):
# This is a python class config
tasks_and_groups[config["task"]] = {
"type": "python_task",
"yaml_path": yaml_path,
}
elif set(config.keys()) <= set(GROUP_KEYS):
print("###")
print(config["group"])
print(config)
print("###")
# This is a group config # This is a group config
tasks_and_groups[config["group"]] = { tasks_and_groups[config["group"]] = {
"type": "group", "type": "group",
......
...@@ -21,7 +21,6 @@ from packaging import version ...@@ -21,7 +21,6 @@ from packaging import version
from lm_eval.api.task import Task from lm_eval.api.task import Task
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_task
_CITATION = """ _CITATION = """
@misc{rajpurkar2018know, @misc{rajpurkar2018know,
...@@ -47,7 +46,6 @@ def _squad_agg(key, items): ...@@ -47,7 +46,6 @@ def _squad_agg(key, items):
return _squad_metric(predictions=predictions, references=references).get(key, 0) return _squad_metric(predictions=predictions, references=references).get(key, 0)
# @register_task("squadv2")
class SQuAD2(Task): class SQuAD2(Task):
VERSION = 3 VERSION = 3
DATASET_PATH = "squad_v2" DATASET_PATH = "squad_v2"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment