"...dcu-process-montor.git" did not exist on "71b1ba0d44ad74de8d45255e217679ed3b2cbe8a"
Commit 5454e95d authored by Baber
Browse files

refactor: use Path

parent 0a184f46
...@@ -3,6 +3,7 @@ import inspect ...@@ -3,6 +3,7 @@ import inspect
import logging import logging
import os import os
from functools import partial from functools import partial
from pathlib import Path
from typing import Dict, List, Mapping, Optional, Union from typing import Dict, List, Mapping, Optional, Union
from lm_eval import utils from lm_eval import utils
...@@ -25,7 +26,7 @@ class TaskManager: ...@@ -25,7 +26,7 @@ class TaskManager:
def __init__( def __init__(
self, self,
verbosity: Optional[str] = None, verbosity: Optional[str] = None,
include_path: Optional[Union[str, List]] = None, include_path: Optional[Union[str, Path, List[Union[str, Path]]]] = None,
include_defaults: bool = True, include_defaults: bool = True,
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
) -> None: ) -> None:
...@@ -56,7 +57,7 @@ class TaskManager: ...@@ -56,7 +57,7 @@ class TaskManager:
def initialize_tasks( def initialize_tasks(
self, self,
include_path: Optional[Union[str, List]] = None, include_path: Optional[Union[str, Path, List[Union[str, Path]]]] = None,
include_defaults: bool = True, include_defaults: bool = True,
) -> dict[str, dict]: ) -> dict[str, dict]:
"""Creates a dictionary of tasks indexes. """Creates a dictionary of tasks indexes.
...@@ -70,13 +71,14 @@ class TaskManager: ...@@ -70,13 +71,14 @@ class TaskManager:
Dictionary of task names as key and task metadata Dictionary of task names as key and task metadata
""" """
if include_defaults: if include_defaults:
all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"] all_paths = [Path(__file__).parent]
else: else:
all_paths = [] all_paths = []
if include_path is not None: if include_path is not None:
if isinstance(include_path, str): if isinstance(include_path, (str, Path)):
include_path = [include_path] include_path = [include_path]
all_paths.extend(include_path) # Convert all paths to Path objects
all_paths.extend(Path(p) for p in include_path)
task_index = {} task_index = {}
for task_dir in all_paths: for task_dir in all_paths:
...@@ -495,7 +497,7 @@ class TaskManager: ...@@ -495,7 +497,7 @@ class TaskManager:
def load_config(self, config: Dict): def load_config(self, config: Dict):
return self._load_individual_task_or_group(config) return self._load_individual_task_or_group(config)
def _get_task_and_group(self, task_dir: str): def _get_task_and_group(self, task_dir: Union[str, Path]):
"""Creates a dictionary of tasks index with the following metadata, """Creates a dictionary of tasks index with the following metadata,
- `type`, that can be either `task`, `python_task`, `group` or `tags`. - `type`, that can be either `task`, `python_task`, `group` or `tags`.
`task` refer to regular task configs, `python_task` are special `task` refer to regular task configs, `python_task` are special
...@@ -547,19 +549,22 @@ class TaskManager: ...@@ -547,19 +549,22 @@ class TaskManager:
".ipynb_checkpoints", ".ipynb_checkpoints",
] ]
tasks_and_groups = collections.defaultdict() tasks_and_groups = collections.defaultdict()
for root, dirs, file_list in os.walk(task_dir): task_dir_path = Path(task_dir)
for root, dirs, file_list in os.walk(task_dir_path):
dirs[:] = [d for d in dirs if d not in ignore_dirs] dirs[:] = [d for d in dirs if d not in ignore_dirs]
root_path = Path(root)
for f in file_list: for f in file_list:
if f.endswith(".yaml"): if f.endswith(".yaml"):
yaml_path = os.path.join(root, f) yaml_path = root_path / f
config = utils.load_yaml_config(yaml_path, mode="simple") config = utils.load_yaml_config(str(yaml_path), mode="simple")
if self._config_is_python_task(config): if self._config_is_python_task(config):
# This is a python class config # This is a python class config
task = config["task"] task = config["task"]
self._register_task( self._register_task(
task, task,
"python_task", "python_task",
yaml_path, str(yaml_path),
tasks_and_groups, tasks_and_groups,
config, config,
_populate_tags_and_groups, _populate_tags_and_groups,
...@@ -573,7 +578,7 @@ class TaskManager: ...@@ -573,7 +578,7 @@ class TaskManager:
# the task list for indexing # the task list for indexing
# as it can be loaded # as it can be loaded
# when called. # when called.
"yaml_path": yaml_path, "yaml_path": str(yaml_path),
} }
# # Registered the level 1 tasks from a group config # # Registered the level 1 tasks from a group config
...@@ -591,7 +596,7 @@ class TaskManager: ...@@ -591,7 +596,7 @@ class TaskManager:
self._register_task( self._register_task(
task, task,
"task", "task",
yaml_path, str(yaml_path),
tasks_and_groups, tasks_and_groups,
config, config,
_populate_tags_and_groups, _populate_tags_and_groups,
...@@ -604,7 +609,7 @@ class TaskManager: ...@@ -604,7 +609,7 @@ class TaskManager:
self._register_task( self._register_task(
task_name, task_name,
"task", "task",
yaml_path, str(yaml_path),
tasks_and_groups, tasks_and_groups,
config, config,
_populate_tags_and_groups, _populate_tags_and_groups,
......
...@@ -467,14 +467,24 @@ def import_function(loader: yaml.Loader, node, yaml_path: Path): ...@@ -467,14 +467,24 @@ def import_function(loader: yaml.Loader, node, yaml_path: Path):
return function return function
def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"): def load_yaml_config(
yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"
) -> dict:
# Convert yaml_path to Path object if it's a string
if yaml_path is not None:
yaml_path = Path(yaml_path)
# Convert yaml_dir to Path object if it's a string
if yaml_dir is not None:
yaml_dir = Path(yaml_dir)
if mode == "simple": if mode == "simple":
constructor_fn = ignore_constructor constructor_fn = ignore_constructor
elif mode == "full": elif mode == "full":
if yaml_path is None: if yaml_path is None:
raise ValueError("yaml_path must be provided if mode is 'full'.") raise ValueError("yaml_path must be provided if mode is 'full'.")
# Attach yaml_path to the import function so that it can be used later # Attach yaml_path to the import function so that it can be used later
constructor_fn = functools.partial(import_function, yaml_path=Path(yaml_path)) constructor_fn = functools.partial(import_function, yaml_path=yaml_path)
loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
# Add the import_function constructor to the YAML loader # Add the import_function constructor to the YAML loader
...@@ -483,8 +493,8 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full ...@@ -483,8 +493,8 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full
with open(yaml_path, "rb") as file: with open(yaml_path, "rb") as file:
yaml_config = yaml.load(file, Loader=loader) yaml_config = yaml.load(file, Loader=loader)
if yaml_dir is None: if yaml_dir is None and yaml_path is not None:
yaml_dir = os.path.dirname(yaml_path) yaml_dir = yaml_path.parent
assert yaml_dir is not None assert yaml_dir is not None
...@@ -499,11 +509,13 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full ...@@ -499,11 +509,13 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full
include_path.reverse() include_path.reverse()
final_yaml_config = {} final_yaml_config = {}
for path in include_path: for path in include_path:
# Convert to Path object
path = Path(path)
# Assumes that path is a full path. # Assumes that path is a full path.
# If not found, assume the included yaml # If not found, assume the included yaml
# is in the same dir as the original yaml # is in the same dir as the original yaml
if not os.path.isfile(path): if not path.is_file():
path = os.path.join(yaml_dir, path) path = yaml_dir / path
try: try:
included_yaml_config = load_yaml_config(yaml_path=path, mode=mode) included_yaml_config = load_yaml_config(yaml_path=path, mode=mode)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment