Commit 68b3cddc authored by Baber

nit

parent 495ea3a0
@@ -11,9 +11,7 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
List,
Mapping,
Optional,
Union,
@@ -30,15 +28,21 @@ from lm_eval.utils import pattern_match, setup_logging
if TYPE_CHECKING:
from lm_eval.api.task import ConfigurableTask, Task
eval_logger = logging.getLogger(__name__)
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
_Base = yaml.CLoader if getattr(yaml, "__with_libyaml__", False) else yaml.FullLoader
_IGNORE_DIRS = (
"__pycache__",
".ipynb_checkpoints",
)
eval_logger = logging.getLogger(__name__)
_Base = yaml.CLoader if getattr(yaml, "__with_libyaml__", False) else yaml.FullLoader
def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None:
return None
@functools.lru_cache(maxsize=None) # ← reuse per (directory, simple) pair
@functools.lru_cache(maxsize=2048) # ← reuse per (directory, simple) pair
def _make_loader(yaml_dir: Path, simple: bool = False) -> type[yaml.Loader]:
"""
Return a custom YAML Loader class bound to *yaml_dir*.
@@ -48,14 +52,13 @@ def _make_loader(yaml_dir: Path, simple: bool = False) -> type[yaml.Loader]:
We capture it so that !function look-ups can resolve relative
Python files like my_utils.some_fn ➜ yaml_dir / "my_utils.py".
simple
If True we ignore !function completely (used by `mode="simple"`).
If True we ignore !function completely (used by `mode="simple"`),
e.g. when TaskManager builds its task index at init.
"""
class Loader(_Base):
"""Dynamically-generated loader that knows its base directory."""
# no extra state needed; the constructor stays the same
# Register (or stub) the !function constructor **for this Loader only**
if simple:
yaml.add_constructor("!function", ignore_constructor, Loader=Loader)
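For context on the `_make_loader` hunk above, here is a minimal sketch of the per-directory loader pattern it implements, assuming only PyYAML's `yaml.add_constructor(tag, fn, Loader=...)` API. The helper name `make_loader` and the inline `construct_function` below are illustrative, not the repository's exact code.

```python
# Illustrative sketch: one Loader subclass per (yaml_dir, simple) pair, cached
# with a bounded LRU, resolving "!function module.attr" relative to yaml_dir.
import functools
import importlib.util
from pathlib import Path

import yaml

_Base = yaml.CLoader if getattr(yaml, "__with_libyaml__", False) else yaml.FullLoader


@functools.lru_cache(maxsize=2048)
def make_loader(yaml_dir: Path, simple: bool = False) -> type:
    class Loader(_Base):
        """Dynamically generated loader bound to one base directory."""

    def construct_function(loader: yaml.Loader, node: yaml.Node):
        # "my_utils.some_fn" -> import yaml_dir/my_utils.py and return some_fn
        module_name, func_name = loader.construct_scalar(node).rsplit(".", 1)
        spec = importlib.util.spec_from_file_location(
            module_name, yaml_dir / f"{module_name}.py"
        )
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        return getattr(mod, func_name)

    if simple:
        # Index-only mode: stub the tag so no Python files are imported.
        yaml.add_constructor("!function", lambda loader, node: None, Loader=Loader)
    else:
        yaml.add_constructor("!function", construct_function, Loader=Loader)
    return Loader
```

Registering the constructor on the per-directory subclass keeps the `simple` stub from leaking into full loads done elsewhere with a different Loader class.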
@@ -90,18 +93,14 @@ def _import_function(qualname: str, *, base_path: Path) -> Callable:
return getattr(mod, func_name)
def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None:
return None
@functools.lru_cache(maxsize=None) #
@functools.lru_cache(maxsize=4096) #
def _parse_yaml_file(path: Path, mode: str) -> dict:
loader_cls = _make_loader(path.parent, simple=(mode == "simple"))
with path.open("rb") as fh:
return yaml.load(fh, Loader=loader_cls)
@functools.lru_cache(maxsize=None)
@functools.lru_cache(maxsize=4096)
def _get_cached_config(yaml_path: Path, mode: str) -> dict:
"""Load and cache resolved YAML configs with LRU eviction."""
# Parse the YAML file
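Both `lru_cache` decorators above move from `maxsize=None` to a bounded cache; a small sketch of the pattern, with hypothetical function and file names:

```python
# Sketch: (path, mode) is hashable, so repeated loads of the same YAML file hit
# the cache instead of re-parsing, while maxsize keeps memory bounded.
import functools
from pathlib import Path

import yaml


@functools.lru_cache(maxsize=4096)
def parse_yaml_file(path: Path, mode: str) -> dict:
    with path.open("rb") as fh:
        return yaml.load(fh, Loader=yaml.SafeLoader)


# parse_yaml_file(Path("some_task.yaml"), "full")   # first call parses the file
# parse_yaml_file(Path("some_task.yaml"), "full")   # second call is a cache hit
# parse_yaml_file.cache_info()                      # hits / misses / currsize
```

One caveat with this pattern: the cached value is a mutable dict shared by all callers, so anything that mutates it should copy first.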
@@ -212,10 +211,9 @@ def load_yaml_config(
def iter_yaml_files(root: Path) -> Generator[Path, Any, None]:
# '**/*.yaml' is handled internally by os.scandir.
for p in iglob("**/*.yaml", root_dir=root, recursive=True):
# ignore check
if p.startswith(("__pycache__", ".ipynb_checkpoints")):
if Path(p).parts[0] in _IGNORE_DIRS:
continue
yield root / p
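A quick comparison of the old prefix check against the new first-component check in `iter_yaml_files`; the relative paths below are made up for illustration:

```python
# Hypothetical relative paths, as yielded by iglob("**/*.yaml", root_dir=root).
from pathlib import Path

_IGNORE_DIRS = ("__pycache__", ".ipynb_checkpoints")

rel_paths = [
    "__pycache__/stale.yaml",       # skipped by both checks
    "__pycache__backup/task.yaml",  # only the old startswith() check skips this
    "some_group/some_task.yaml",    # kept by both checks
]

for p in rel_paths:
    old = p.startswith(_IGNORE_DIRS)        # prefix match on the raw string
    new = Path(p).parts[0] in _IGNORE_DIRS  # exact match on the first directory
    print(f"{p!r:34} old-skip={old!s:5} new-skip={new!s}")
```

Note that the `root_dir=` parameter of `iglob` used above requires Python 3.10+.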
@@ -229,7 +227,7 @@ class TaskManager:
def __init__(
self,
verbosity: Optional[str] = None,
include_path: Optional[Union[str, Path, List[Union[str, Path]]]] = None,
include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None,
include_defaults: bool = True,
metadata: Optional[dict] = None,
) -> None:
@@ -260,18 +258,18 @@
def initialize_tasks(
self,
include_path: Optional[Union[str, Path, List[Union[str, Path]]]] = None,
include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None,
include_defaults: bool = True,
) -> dict[str, dict]:
"""Creates a dictionary of tasks indexes.
:param include_path: Union[str, List] = None
:param include_path: Union[str, list] = None
An additional path to be searched for tasks recursively.
Can provide more than one such path as a list.
:param include_defaults: bool = True
If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
return
Dictionary of task names as key and task metadata
dictionary with task names as keys and task metadata as values
"""
if include_defaults:
all_paths = [Path(__file__).parent]
@@ -291,23 +289,23 @@ class TaskManager:
return task_index
@property
def all_tasks(self) -> List[str]:
def all_tasks(self) -> list[str]:
return self._all_tasks
@property
def all_groups(self) -> List[str]:
def all_groups(self) -> list[str]:
return self._all_groups
@property
def all_subtasks(self) -> List[str]:
def all_subtasks(self) -> list[str]:
return self._all_subtasks
@property
def all_tags(self) -> List[str]:
def all_tags(self) -> list[str]:
return self._all_tags
@property
def task_index(self) -> Dict[str, Dict[str, Union[str, int, List[str]]]]:
def task_index(self) -> dict[str, dict[str, Union[str, int, list[str]]]]:
return self._task_index
def list_all_tasks(
@@ -340,15 +338,16 @@ class TaskManager:
inc_list = inc_raw if isinstance(inc_raw, list) else [inc_raw]
for inc in inc_list:
inc_path = Path(inc)
if not inc_path.is_absolute(): # treat as relative include
inc_path = base.parent / inc_path
try:
inc_cfg = load_yaml_config(inc_path, mode="simple")
except FileNotFoundError:
continue
if "output_type" in inc_cfg:
return inc_cfg["output_type"]
if inc:
inc_path = Path(inc)
if not inc_path.is_absolute(): # treat as relative include
inc_path = base.parent / inc_path
try:
inc_cfg = load_yaml_config(inc_path, mode="simple")
except FileNotFoundError:
continue
if "output_type" in inc_cfg:
return inc_cfg["output_type"]
return ""
# -------------------------------------------------------------- GROUP table
@@ -452,7 +451,7 @@ class TaskManager:
raise ValueError
return self.task_index[name]["yaml_path"]
def _get_config(self, name: str) -> Dict:
def _get_config(self, name: str) -> dict:
if name not in self.task_index:
raise ValueError
yaml_path = self._get_yaml_path(name)
@@ -461,7 +460,7 @@ class TaskManager:
else:
return load_yaml_config(Path(yaml_path), mode="full")
def _get_tasklist(self, name: str) -> Union[List[str], int]:
def _get_tasklist(self, name: str) -> Union[list[str], int]:
if self._name_is_task(name):
raise ValueError
return self.task_index[name]["task"]
@@ -471,8 +470,8 @@ class TaskManager:
task_name: str,
task_type: str,
yaml_path: str,
tasks_and_groups: Dict[str, Dict],
config: Optional[Dict] = None,
tasks_and_groups: dict[str, dict],
config: Optional[dict] = None,
populate_tags_fn: Optional[callable] = None,
) -> None:
"""Helper method to register a task in the tasks_and_groups dict"""
@@ -485,8 +484,8 @@ class TaskManager:
populate_tags_fn(config, task_name, tasks_and_groups)
def _merge_task_configs(
self, base_config: Dict, task_specific_config: Dict, task_name: str
) -> Dict:
self, base_config: dict, task_specific_config: dict, task_name: str
) -> dict:
"""Merge base config with task-specific overrides for task_list configs"""
if task_specific_config:
task_specific_config = task_specific_config.copy()
@@ -495,8 +494,8 @@ class TaskManager:
return {**base_config, "task": task_name}
def _process_tag_subtasks(
self, tag_name: str, update_config: Optional[Dict] = None
) -> Dict:
self, tag_name: str, update_config: Optional[dict] = None
) -> dict:
"""Process subtasks for a tag and return loaded tasks"""
subtask_list = self._get_tasklist(tag_name)
fn = partial(
@@ -505,7 +504,7 @@
)
return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
def _process_alias(self, config: Dict, group: Optional[str] = None) -> Dict:
def _process_alias(self, config: dict, group: Optional[str] = None) -> dict:
# If the group is not the same as the original
# group which the group alias was intended for,
# Set the group_alias to None instead.
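`_process_tag_subtasks` above flattens the per-subtask dicts with a `ChainMap` over the reversed list; a small sketch of why the reversal matters for ordering, with placeholder strings standing in for Task objects:

```python
# Each subtask loads into its own {task_name: task_object} dict (placeholders here).
import collections

per_subtask = [{"task_a": "obj_a"}, {"task_b": "obj_b"}, {"task_c": "obj_c"}]

merged = dict(collections.ChainMap(*reversed(per_subtask)))
# ChainMap iterates keys starting from its last map, so reversing the inputs
# keeps the merged dict's key order aligned with the original subtask order.
print(list(merged))  # ['task_a', 'task_b', 'task_c']
```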
@@ -524,15 +523,15 @@
def _load_individual_task_or_group(
self,
name_or_config: Optional[Union[str, Dict]] = None,
name_or_config: Optional[Union[str, dict]] = None,
parent_name: Optional[str] = None,
update_config: Optional[Dict] = None,
update_config: Optional[dict] = None,
) -> Mapping:
from lm_eval.api.task import ConfigurableTask, Task
def _load_task(
config: Dict, task: str, yaml_path: Optional[str] = None
) -> Dict[str, Union["ConfigurableTask", "Task"]]:
config: dict, task: str, yaml_path: Optional[str] = None
) -> dict[str, Union["ConfigurableTask", "Task"]]:
if "include" in config:
# Store the task name to preserve it after include processing
original_task_name = config.get("task", task)
@@ -568,8 +567,8 @@
return {task: task_object}
def _get_group_and_subtask_from_config(
config: Dict,
) -> tuple[ConfigurableGroup, List[str]]:
config: dict,
) -> tuple[ConfigurableGroup, list[str]]:
if self.metadata is not None:
config["metadata"] = config.get("metadata", {}) | self.metadata
group_name = ConfigurableGroup(config=config)
@@ -582,8 +581,8 @@
return group_name, subtask_list
def _process_group_config(
config: Dict, update_config: Optional[Dict] = None
) -> tuple[Dict, Optional[Dict]]:
config: dict, update_config: Optional[dict] = None
) -> tuple[dict, Optional[dict]]:
if update_config is not None:
config = {**config, **update_config}
_update_config = {
@@ -716,15 +715,15 @@
}
def load_task_or_group(
self, task_list: Optional[Union[str, List[str]]] = None
) -> Dict:
self, task_list: Optional[Union[str, list[str]]] = None
) -> dict:
"""Loads a dictionary of task objects from a list
:param task_list: Union[str, list] = None
Single string or list of string of task names to be loaded
:return
Dictionary of task objects
dictionary of task objects
"""
if isinstance(task_list, str):
task_list = [task_list]
@@ -739,10 +738,10 @@
)
return all_loaded_tasks
def load_config(self, config: Dict) -> Mapping:
def load_config(self, config: dict) -> Mapping:
return self._load_individual_task_or_group(config)
def _get_task_and_group(self, task_dir: Union[str, Path]) -> Dict[str, Dict]:
def _get_task_and_group(self, task_dir: Union[str, Path]) -> dict[str, dict]:
"""Creates a dictionary of tasks index with the following metadata,
- `type`, that can be either `task`, `python_task`, `group` or `tags`.
`task` refer to regular task configs, `python_task` are special
@@ -762,11 +761,11 @@
A directory to check for tasks
:return
Dictionary of task names as key and task metadata
dictionary with task names as keys and task metadata as values
"""
def _populate_tags_and_groups(
config: Dict, task: str, tasks_and_groups: Dict[str, Dict]
config: dict, task: str, tasks_and_groups: dict[str, dict]
) -> None:
# TODO: remove group in next release
if "tag" in config:
@@ -868,7 +867,7 @@
return tasks_and_groups
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
def get_task_name_from_config(task_config: dict[str, str]) -> str:
if "task" in task_config:
return task_config["task"]
if "dataset_name" in task_config:
@@ -890,7 +889,7 @@ def get_task_name_from_object(task_object: Union["ConfigurableTask", "Task"]) ->
)
def _check_duplicates(task_dict: Dict[str, List[str]]) -> None:
def _check_duplicates(task_dict: dict[str, list[str]]) -> None:
"""helper function solely used in validating get_task_dict output.
Takes the output of lm_eval.evaluator_utils.get_subtask_list and
returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are
@@ -918,12 +917,12 @@ def _check_duplicates(task_dict: Dict[str, List[str]]) -> None:
def get_task_dict(
task_name_list: Union[str, List[Union[str, Dict, "Task"]]],
task_name_list: Union[str, list[Union[str, dict, "Task"]]],
task_manager: Optional[TaskManager] = None,
) -> Dict[str, Union["ConfigurableTask", "Task"]]:
) -> dict[str, Union["ConfigurableTask", "Task"]]:
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
:param task_name_list: List[Union[str, Dict, Task]]
:param task_name_list: list[Union[str, dict, Task]]
List of task names, config dicts, or prepared Task objects to load.
:param task_manager: TaskManager = None
A TaskManager object that stores indexed tasks. If not set,
@@ -932,7 +931,7 @@ def get_task_dict(
via `include_path`
:return
Dictionary of task objects
dictionary of task objects
"""
from lm_eval.api.task import ConfigurableTask, Task
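Finally, a hedged usage sketch of the public entry points whose signatures are touched in this diff; the task name is a placeholder and may not exist in a given install:

```python
from lm_eval.tasks import TaskManager, get_task_dict

tm = TaskManager(include_defaults=True)  # builds the task index on init
print(len(tm.all_tasks))                 # names discovered under lm_eval/tasks/

# task_name_list may mix plain strings, config dicts, and prepared Task objects.
task_dict = get_task_dict(["some_task_name"], task_manager=tm)
```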