Commit 55ea2888 authored by lintangsutawika's avatar lintangsutawika
Browse files

initialize_tasks returns list of tasks and groups

parent 89c09169
...@@ -29,6 +29,12 @@ from .scrolls.task import ( ...@@ -29,6 +29,12 @@ from .scrolls.task import (
eval_logger = utils.eval_logger eval_logger = utils.eval_logger
def load_task_or_group(yaml_path: str) -> ConfigurableTask:
config = utils.load_yaml_config(yaml_path)
return ConfigurableTask(config=config)
def register_configurable_task(config: Dict[str, str]) -> int: def register_configurable_task(config: Dict[str, str]) -> int:
SubClass = type( SubClass = type(
config["task"] + "ConfigurableTask", config["task"] + "ConfigurableTask",
...@@ -84,8 +90,6 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) - ...@@ -84,8 +90,6 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -
if task_obj is not None: if task_obj is not None:
base_config = task_obj._config.to_dict(keep_callable=True) base_config = task_obj._config.to_dict(keep_callable=True)
task_name_config["task"] = f"{group}_{task_name}" task_name_config["task"] = f"{group}_{task_name}"
# elif task_name in GROUP_REGISTRY:
task_config = utils.load_yaml_config(yaml_path, task_config) task_config = utils.load_yaml_config(yaml_path, task_config)
var_configs = check_prompt_config( var_configs = check_prompt_config(
...@@ -219,19 +223,40 @@ def include_task_folder(task_dir: str, register_task: bool = True, task_name: st ...@@ -219,19 +223,40 @@ def include_task_folder(task_dir: str, register_task: bool = True, task_name: st
return 0 return 0
def include_path(task_dir, task_name=None): def get_task_and_group(task_dir: str):
include_task_folder(task_dir, task_name=task_name) task_list = {}
# Register Benchmarks after all tasks have been added group_list = {}
include_task_folder(task_dir, register_task=False, task_name=task_name) for root, _, file_list in os.walk(task_dir):
return 0 for f in file_list:
if f.endswith(".yaml"):
yaml_path = os.path.join(root, f)
def initialize_tasks(verbosity="INFO", task_name=None): config = utils.simple_load_yaml_config(yaml_path)
if "task" in config and isinstance(config["task"], str):
task_list[config["task"]] = yaml_path
if "group" in config:
if isinstance(config["group"], str):
group_list[config["group"]] = yaml_path
elif isinstance(config["group"], list):
for group in config["group"]:
group_list[group] = yaml_path
return task_list, group_list
def initialize_tasks(verbosity="INFO", include_path=None):
eval_logger.setLevel(getattr(logging, f"{verbosity}")) eval_logger.setLevel(getattr(logging, f"{verbosity}"))
all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/" if include_path is not None:
include_path(task_dir, task_name=task_name) if isinstance(include_path, str):
include_path = [include_path]
all_paths.extend(include_path)
task_list = {}
group_list = {}
for task_dir in all_paths:
tasks, groups = get_task_and_group(task_dir)
task_list = {**tasks, **task_list}
group_list = {**groups, **group_list}
return task_list, group_list
def get_task(task_name, config): def get_task(task_name, config):
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment