Commit 9118d998 authored by Nathan Habib's avatar Nathan Habib
Browse files

checkout from main

parent d26aeda7
......@@ -14,27 +14,43 @@ class TaskManager:
"""
def __init__(self, verbosity="INFO", include_path: Optional[str] = None) -> None:
def __init__(
self,
verbosity="INFO",
include_path: Optional[Union[str, List]] = None,
include_defaults: bool = True,
) -> None:
self.verbosity = verbosity
self.include_path = include_path
self.logger = utils.eval_logger
self.logger.setLevel(getattr(logging, f"{verbosity}"))
self._task_index = self.initialize_tasks(include_path=include_path)
self._task_index = self.initialize_tasks(
include_path=include_path, include_defaults=include_defaults
)
self._all_tasks = sorted(list(self._task_index.keys()))
self.task_group_map = collections.defaultdict(list)
def initialize_tasks(self, include_path: Optional[str] = None):
def initialize_tasks(
self,
include_path: Optional[Union[str, List]] = None,
include_defaults: bool = True,
):
"""Creates a dictionary of tasks index.
:param include_path: str = None
An additional path to be searched for tasks
:param include_path: Union[str, List] = None
An additional path to be searched for tasks recursively.
Can provide more than one such path as a list.
:param include_defaults: bool = True
If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
:return
Dictionary of task names as key and task metadata
"""
all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
if include_defaults:
all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
else:
all_paths = []
if include_path is not None:
if isinstance(include_path, str):
include_path = [include_path]
......@@ -296,8 +312,13 @@ class TaskManager:
:return
Dictionary of task names as key and task metadata
"""
ignore_dirs = [
"__pycache__",
".ipynb_checkpoints",
]
tasks_and_groups = collections.defaultdict()
for root, _, file_list in os.walk(task_dir):
for root, dirs, file_list in os.walk(task_dir):
dirs[:] = [d for d in dirs if d not in ignore_dirs]
for f in file_list:
if f.endswith(".yaml"):
yaml_path = os.path.join(root, f)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment