Commit 28cc5b6e authored by lintangsutawika
Browse files

indexing and loading are part of a task_manager object

parent 17172a26
......@@ -9,7 +9,7 @@ from typing import Union
import numpy as np
from lm_eval import evaluator, utils
from lm_eval.tasks import initialize_tasks, load_task_or_group
from lm_eval.tasks import TaskManager
from lm_eval.utils import make_table
......@@ -155,7 +155,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# initialize_tasks(args.verbosity)
ALL_TASKS = initialize_tasks(args.verbosity, include_path=args.include_path)
task_manager = TaskManager(args.verbosity, include_path=args.include_path)
if args.limit:
eval_logger.warning(
......@@ -170,7 +170,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
sys.exit()
elif args.tasks == "list":
eval_logger.info(
"Available Tasks:\n - {}".format("\n - ".join(sorted(ALL_TASKS.keys())))
"Available Tasks:\n - {}".format("\n - ".join(task_manager.all_tasks()))
)
else:
if os.path.isdir(args.tasks):
......@@ -183,7 +183,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
loaded_task_list.append(config)
else:
input_task_list = args.tasks.split(",")
loaded_task_list = utils.pattern_match(input_task_list, ALL_TASKS.keys())
loaded_task_list = utils.pattern_match(input_task_list, task_manager.all_tasks())
for task in [
task for task in input_task_list if task not in loaded_task_list
]:
......@@ -229,25 +229,11 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
eval_logger.info(f"Selected Tasks: {loaded_task_list}")
eval_logger.info("Loading selected tasks...")
all_tasks = {}
for task in loaded_task_list:
task_object = load_task_or_group(
ALL_TASKS,
task_name_or_config=task,
)
if isinstance(task, str):
task_name = task
elif isinstance(task, dict):
task_name = task["task"]
if isinstance(task_object, dict):
all_tasks = {**task_object, **all_tasks}
else:
all_tasks[task_name] = task_object
all_tasks = task_manager.load_task_or_group(loaded_task_list)
# for key, value in all_tasks.items():
# print(key, value)
# import sys; sys.exit()
for key, value in all_tasks.items():
print(key, value)
import sys; sys.exit()
results = evaluator.simple_evaluate(
model=args.model,
......
import os
import abc
import yaml
import collections
from typing import List, Union, Dict
......@@ -11,7 +12,7 @@ from lm_eval.api.registry import (
register_group,
TASK_REGISTRY,
GROUP_REGISTRY,
ALL_TASKS,
self.ALL_TASKS,
)
import logging
......@@ -35,11 +36,28 @@ def is_group(task):
return True
return False
class TaskManager(abc.ABC):
def load_task_or_group(ALL_TASKS, task_name_or_config: Union[str, dict] = None) -> ConfigurableTask:
def __init__(self, verbosity="INFO", include_path=None) -> None:
    """Build the task index once and keep it on the instance.

    Both arguments are forwarded unchanged to ``initialize_tasks``; the
    value it returns is stored as ``self.ALL_TASKS``.

    Args:
        verbosity: Passed through as ``verbosity`` (defaults to "INFO").
        include_path: Passed through as ``include_path``; presumably one or
            more extra directories to search for tasks -- confirm against
            ``initialize_tasks``.
    """
    task_index = initialize_tasks(verbosity=verbosity, include_path=include_path)
    self.ALL_TASKS = task_index
@property
def all_tasks(self):
    """Return the keys of ``self.ALL_TASKS`` as a sorted list.

    NOTE(review): this is a property, so it is read as
    ``manager.all_tasks``. Writing ``manager.all_tasks()`` would attempt to
    call the returned list and fail.
    """
    names = list(self.ALL_TASKS.keys())
    names.sort()
    return names
def _load_individual_task_or_group(self, task_name_or_config: Union[str, dict] = None) -> ConfigurableTask:
print("Loading", task_name_or_config)
if isinstance(task_name_or_config, str):
task_info = ALL_TASKS[task_name_or_config]
task_info = self.ALL_TASKS[task_name_or_config]
yaml_path = task_info["yaml_path"]
task_type = task_info["type"]
subtask_list = task_info["task"] if "task" in task_info else -1
......@@ -58,13 +76,13 @@ def load_task_or_group(ALL_TASKS, task_name_or_config: Union[str, dict] = None)
for task_or_config in subtask_list:
if isinstance(task_or_config, str):
all_subtasks[task_or_config] = (group_name, None)
task_object = load_task_or_group(ALL_TASKS, task_name_or_config=task_or_config)
task_object = self._load_individual_task_or_group(self.ALL_TASKS, task_name_or_config=task_or_config)
elif isinstance(task_or_config, dict):
if "group" in task_or_config:
all_subtasks[task_or_config["group"]] = (group_name, None)
elif "task" in task_or_config:
all_subtasks[task_or_config["task"]] = (group_name, None)
task_object = load_task_or_group(ALL_TASKS, task_name_or_config=task_or_config)
task_object = self._load_individual_task_or_group(self.ALL_TASKS, task_name_or_config=task_or_config)
if isinstance(task_object, dict):
all_subtasks = {**task_object, **all_subtasks}
......@@ -83,10 +101,10 @@ def load_task_or_group(ALL_TASKS, task_name_or_config: Union[str, dict] = None)
all_subtasks = {}
for task_or_config in subtask_list:
if isinstance(task_or_config, str):
task_object = load_task_or_group(ALL_TASKS, task_name_or_config=task_or_config)
task_object = self._load_individual_task_or_group(self.ALL_TASKS, task_name_or_config=task_or_config)
task_name = task_or_config
elif isinstance(task_or_config, dict):
task_object = load_task_or_group(ALL_TASKS, task_name_or_config=task_or_config)
task_object = self._load_individual_task_or_group(self.ALL_TASKS, task_name_or_config=task_or_config)
if isinstance(task_object, dict):
all_subtasks = {**task_object, **all_subtasks}
......@@ -97,7 +115,7 @@ def load_task_or_group(ALL_TASKS, task_name_or_config: Union[str, dict] = None)
else:
task_type = "task"
task_name = task_name_or_config["task"]
base_task_info = ALL_TASKS[task_name]
base_task_info = self.ALL_TASKS[task_name]
base_yaml_path = base_task_info["yaml_path"]
base_task_config = utils.load_yaml_config(base_yaml_path)
......@@ -108,6 +126,28 @@ def load_task_or_group(ALL_TASKS, task_name_or_config: Union[str, dict] = None)
}
)
def load_task_or_group(self, task_list: Union[str, list] = None) -> dict:
    """Load every requested task or group and merge the results.

    Args:
        task_list: A single task/group name, or a list whose items are
            either names (``str``) or config dicts. ``None`` is treated as
            an empty request.

    Returns:
        A dict of everything that was loaded. When
        ``_load_individual_task_or_group`` returns a dict, its entries are
        merged in (entries loaded earlier win on key clashes); otherwise
        the returned object is stored under the task's name.

    Raises:
        KeyError: If a config dict yields a single (non-dict) object but
            has no ``"task"`` key to name it by.
    """
    if task_list is None:
        return {}
    if isinstance(task_list, str):
        task_list = [task_list]

    all_loaded_tasks = {}
    for task in task_list:
        task_object = self._load_individual_task_or_group(
            task_name_or_config=task,
        )
        if isinstance(task_object, dict):
            # Merge into the accumulator, not into self.ALL_TASKS: the index
            # holds metadata for every known task and must not leak into
            # (or wipe out) the set of tasks actually loaded so far.
            all_loaded_tasks = {**task_object, **all_loaded_tasks}
        else:
            # The name is only needed here, so a config dict without a
            # "task" key is fine whenever the loader returned a dict.
            task_name = task if isinstance(task, str) else task["task"]
            all_loaded_tasks[task_name] = task_object

    return all_loaded_tasks
def register_configurable_task(config: Dict[str, str]) -> int:
SubClass = type(
......@@ -182,16 +222,16 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -
GROUP_REGISTRY[group].append(sub_group)
else:
GROUP_REGISTRY[group] = [sub_group]
ALL_TASKS.add(group)
self.ALL_TASKS.add(group)
task_names = utils.pattern_match(registered_task_or_group_list, ALL_TASKS)
task_names = utils.pattern_match(registered_task_or_group_list, self.ALL_TASKS)
for task in task_names:
if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
if group in GROUP_REGISTRY:
GROUP_REGISTRY[group].append(task)
else:
GROUP_REGISTRY[group] = [task]
ALL_TASKS.add(group)
self.ALL_TASKS.add(group)
return 0
......@@ -345,12 +385,12 @@ def initialize_tasks(verbosity="INFO", include_path=None):
include_path = [include_path]
all_paths.extend(include_path)
ALL_TASKS = {}
self.ALL_TASKS = {}
for task_dir in all_paths:
tasks = get_task_and_group(task_dir)
ALL_TASKS = {**tasks, **ALL_TASKS}
self.ALL_TASKS = {**tasks, **self.ALL_TASKS}
return ALL_TASKS
return self.ALL_TASKS
def get_task(task_name, config):
try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment