Commit 702d86aa authored by lintangsutawika's avatar lintangsutawika
Browse files

adapted initialize_tasks

parent 28cc5b6e
......@@ -12,7 +12,6 @@ from lm_eval.api.registry import (
register_group,
TASK_REGISTRY,
GROUP_REGISTRY,
self.ALL_TASKS,
)
import logging
......@@ -44,11 +43,29 @@ class TaskManager(abc.ABC):
include_path=None
) -> None:
self.ALL_TASKS = initialize_tasks(
verbosity=verbosity,
self.verbosity = verbosity
self.include_path = include_path
self.eval_logger.setLevel(getattr(logging, f"{verbosity}"))
self.ALL_TASKS = self.initialize_tasks(
include_path=include_path
)
def initialize_tasks(self, include_path=None):
all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
if include_path is not None:
if isinstance(include_path, str):
include_path = [include_path]
all_paths.extend(include_path)
ALL_TASKS = {}
for task_dir in all_paths:
tasks = get_task_and_group(task_dir)
ALL_TASKS = {**tasks, **ALL_TASKS}
return ALL_TASKS
@property
def all_tasks(self):
return sorted(self.ALL_TASKS.keys())
......@@ -377,20 +394,7 @@ def get_task_and_group(task_dir: str):
return tasks_and_groups
def initialize_tasks(verbosity="INFO", include_path=None):
eval_logger.setLevel(getattr(logging, f"{verbosity}"))
all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
if include_path is not None:
if isinstance(include_path, str):
include_path = [include_path]
all_paths.extend(include_path)
self.ALL_TASKS = {}
for task_dir in all_paths:
tasks = get_task_and_group(task_dir)
self.ALL_TASKS = {**tasks, **self.ALL_TASKS}
return self.ALL_TASKS
def get_task(task_name, config):
try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment