Commit 5454e95d authored by Baber's avatar Baber
Browse files

refactor: use Path

parent 0a184f46
......@@ -3,6 +3,7 @@ import inspect
import logging
import os
from functools import partial
from pathlib import Path
from typing import Dict, List, Mapping, Optional, Union
from lm_eval import utils
......@@ -25,7 +26,7 @@ class TaskManager:
def __init__(
self,
verbosity: Optional[str] = None,
include_path: Optional[Union[str, List]] = None,
include_path: Optional[Union[str, Path, List[Union[str, Path]]]] = None,
include_defaults: bool = True,
metadata: Optional[dict] = None,
) -> None:
......@@ -56,7 +57,7 @@ class TaskManager:
def initialize_tasks(
self,
include_path: Optional[Union[str, List]] = None,
include_path: Optional[Union[str, Path, List[Union[str, Path]]]] = None,
include_defaults: bool = True,
) -> dict[str, dict]:
"""Creates a dictionary of tasks indexes.
......@@ -70,13 +71,14 @@ class TaskManager:
Dictionary of task names as key and task metadata
"""
if include_defaults:
all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
all_paths = [Path(__file__).parent]
else:
all_paths = []
if include_path is not None:
if isinstance(include_path, str):
if isinstance(include_path, (str, Path)):
include_path = [include_path]
all_paths.extend(include_path)
# Convert all paths to Path objects
all_paths.extend(Path(p) for p in include_path)
task_index = {}
for task_dir in all_paths:
......@@ -495,7 +497,7 @@ class TaskManager:
def load_config(self, config: Dict):
return self._load_individual_task_or_group(config)
def _get_task_and_group(self, task_dir: str):
def _get_task_and_group(self, task_dir: Union[str, Path]):
"""Creates a dictionary of tasks index with the following metadata,
- `type`, that can be either `task`, `python_task`, `group` or `tags`.
`task` refer to regular task configs, `python_task` are special
......@@ -547,19 +549,22 @@ class TaskManager:
".ipynb_checkpoints",
]
tasks_and_groups = collections.defaultdict()
for root, dirs, file_list in os.walk(task_dir):
task_dir_path = Path(task_dir)
for root, dirs, file_list in os.walk(task_dir_path):
dirs[:] = [d for d in dirs if d not in ignore_dirs]
root_path = Path(root)
for f in file_list:
if f.endswith(".yaml"):
yaml_path = os.path.join(root, f)
config = utils.load_yaml_config(yaml_path, mode="simple")
yaml_path = root_path / f
config = utils.load_yaml_config(str(yaml_path), mode="simple")
if self._config_is_python_task(config):
# This is a python class config
task = config["task"]
self._register_task(
task,
"python_task",
yaml_path,
str(yaml_path),
tasks_and_groups,
config,
_populate_tags_and_groups,
......@@ -573,7 +578,7 @@ class TaskManager:
# the task list for indexing
# as it can be loaded
# when called.
"yaml_path": yaml_path,
"yaml_path": str(yaml_path),
}
# # Registered the level 1 tasks from a group config
......@@ -591,7 +596,7 @@ class TaskManager:
self._register_task(
task,
"task",
yaml_path,
str(yaml_path),
tasks_and_groups,
config,
_populate_tags_and_groups,
......@@ -604,7 +609,7 @@ class TaskManager:
self._register_task(
task_name,
"task",
yaml_path,
str(yaml_path),
tasks_and_groups,
config,
_populate_tags_and_groups,
......
......@@ -467,14 +467,24 @@ def import_function(loader: yaml.Loader, node, yaml_path: Path):
return function
def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"):
def load_yaml_config(
yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"
) -> dict:
# Convert yaml_path to Path object if it's a string
if yaml_path is not None:
yaml_path = Path(yaml_path)
# Convert yaml_dir to Path object if it's a string
if yaml_dir is not None:
yaml_dir = Path(yaml_dir)
if mode == "simple":
constructor_fn = ignore_constructor
elif mode == "full":
if yaml_path is None:
raise ValueError("yaml_path must be provided if mode is 'full'.")
# Attach yaml_path to the import function so that it can be used later
constructor_fn = functools.partial(import_function, yaml_path=Path(yaml_path))
constructor_fn = functools.partial(import_function, yaml_path=yaml_path)
loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
# Add the import_function constructor to the YAML loader
......@@ -483,8 +493,8 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full
with open(yaml_path, "rb") as file:
yaml_config = yaml.load(file, Loader=loader)
if yaml_dir is None:
yaml_dir = os.path.dirname(yaml_path)
if yaml_dir is None and yaml_path is not None:
yaml_dir = yaml_path.parent
assert yaml_dir is not None
......@@ -499,11 +509,13 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full
include_path.reverse()
final_yaml_config = {}
for path in include_path:
# Convert to Path object
path = Path(path)
# Assumes that path is a full path.
# If not found, assume the included yaml
# is in the same dir as the original yaml
if not os.path.isfile(path):
path = os.path.join(yaml_dir, path)
if not path.is_file():
path = yaml_dir / path
try:
included_yaml_config = load_yaml_config(yaml_path=path, mode=mode)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment