Commit 68b3cddc authored by Baber's avatar Baber
Browse files

nit

parent 495ea3a0
...@@ -11,9 +11,7 @@ from typing import ( ...@@ -11,9 +11,7 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Callable, Callable,
Dict,
Generator, Generator,
List,
Mapping, Mapping,
Optional, Optional,
Union, Union,
...@@ -30,15 +28,21 @@ from lm_eval.utils import pattern_match, setup_logging ...@@ -30,15 +28,21 @@ from lm_eval.utils import pattern_match, setup_logging
if TYPE_CHECKING: if TYPE_CHECKING:
from lm_eval.api.task import ConfigurableTask, Task from lm_eval.api.task import ConfigurableTask, Task
eval_logger = logging.getLogger(__name__)
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys()) GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
_Base = yaml.CLoader if getattr(yaml, "__with_libyaml__", False) else yaml.FullLoader
_IGNORE_DIRS = (
"__pycache__",
".ipynb_checkpoints",
)
eval_logger = logging.getLogger(__name__) def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None:
_Base = yaml.CLoader if getattr(yaml, "__with_libyaml__", False) else yaml.FullLoader return None
@functools.lru_cache(maxsize=None) # ← reuse per (directory, simple) pair @functools.lru_cache(maxsize=2048) # ← reuse per (directory, simple) pair
def _make_loader(yaml_dir: Path, simple: bool = False) -> type[yaml.Loader]: def _make_loader(yaml_dir: Path, simple: bool = False) -> type[yaml.Loader]:
""" """
Return a custom YAML Loader class bound to *yaml_dir*. Return a custom YAML Loader class bound to *yaml_dir*.
...@@ -48,14 +52,13 @@ def _make_loader(yaml_dir: Path, simple: bool = False) -> type[yaml.Loader]: ...@@ -48,14 +52,13 @@ def _make_loader(yaml_dir: Path, simple: bool = False) -> type[yaml.Loader]:
We capture it so that !function look-ups can resolve relative We capture it so that !function look-ups can resolve relative
Python files like my_utils.some_fn ➜ yaml_dir / "my_utils.py". Python files like my_utils.some_fn ➜ yaml_dir / "my_utils.py".
simple simple
If True we ignore !function completely (used by `mode="simple"`). If True we ignore !function completely (used by `mode="simple"`),
used on TaskManager init to index.
""" """
class Loader(_Base): class Loader(_Base):
"""Dynamically-generated loader that knows its base directory.""" """Dynamically-generated loader that knows its base directory."""
# no extra state needed; the constructor stays the same
# Register (or stub) the !function constructor **for this Loader only** # Register (or stub) the !function constructor **for this Loader only**
if simple: if simple:
yaml.add_constructor("!function", ignore_constructor, Loader=Loader) yaml.add_constructor("!function", ignore_constructor, Loader=Loader)
...@@ -90,18 +93,14 @@ def _import_function(qualname: str, *, base_path: Path) -> Callable: ...@@ -90,18 +93,14 @@ def _import_function(qualname: str, *, base_path: Path) -> Callable:
return getattr(mod, func_name) return getattr(mod, func_name)
def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None: @functools.lru_cache(maxsize=4096) #
return None
@functools.lru_cache(maxsize=None) #
def _parse_yaml_file(path: Path, mode: str) -> dict: def _parse_yaml_file(path: Path, mode: str) -> dict:
loader_cls = _make_loader(path.parent, simple=(mode == "simple")) loader_cls = _make_loader(path.parent, simple=(mode == "simple"))
with path.open("rb") as fh: with path.open("rb") as fh:
return yaml.load(fh, Loader=loader_cls) return yaml.load(fh, Loader=loader_cls)
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=4096)
def _get_cached_config(yaml_path: Path, mode: str) -> dict: def _get_cached_config(yaml_path: Path, mode: str) -> dict:
"""Load and cache resolved YAML configs with LRU eviction.""" """Load and cache resolved YAML configs with LRU eviction."""
# Parse the YAML file # Parse the YAML file
...@@ -212,10 +211,9 @@ def load_yaml_config( ...@@ -212,10 +211,9 @@ def load_yaml_config(
def iter_yaml_files(root: Path) -> Generator[Path, Any, None]: def iter_yaml_files(root: Path) -> Generator[Path, Any, None]:
# '**/*.yaml' is handled internally by os.scandir.
for p in iglob("**/*.yaml", root_dir=root, recursive=True): for p in iglob("**/*.yaml", root_dir=root, recursive=True):
# ignore check # ignore check
if p.startswith(("__pycache__", ".ipynb_checkpoints")): if Path(p).parts[0] in _IGNORE_DIRS:
continue continue
yield root / p yield root / p
...@@ -229,7 +227,7 @@ class TaskManager: ...@@ -229,7 +227,7 @@ class TaskManager:
def __init__( def __init__(
self, self,
verbosity: Optional[str] = None, verbosity: Optional[str] = None,
include_path: Optional[Union[str, Path, List[Union[str, Path]]]] = None, include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None,
include_defaults: bool = True, include_defaults: bool = True,
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
) -> None: ) -> None:
...@@ -260,18 +258,18 @@ class TaskManager: ...@@ -260,18 +258,18 @@ class TaskManager:
def initialize_tasks( def initialize_tasks(
self, self,
include_path: Optional[Union[str, Path, List[Union[str, Path]]]] = None, include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None,
include_defaults: bool = True, include_defaults: bool = True,
) -> dict[str, dict]: ) -> dict[str, dict]:
"""Creates a dictionary of tasks indexes. """Creates a dictionary of tasks indexes.
:param include_path: Union[str, List] = None :param include_path: Union[str, list] = None
An additional path to be searched for tasks recursively. An additional path to be searched for tasks recursively.
Can provide more than one such path as a list. Can provide more than one such path as a list.
:param include_defaults: bool = True :param include_defaults: bool = True
If set to false, default tasks (those in lm_eval/tasks/) are not indexed. If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
return return
Dictionary of task names as key and task metadata dictionary of task names as key and task metadata
""" """
if include_defaults: if include_defaults:
all_paths = [Path(__file__).parent] all_paths = [Path(__file__).parent]
...@@ -291,23 +289,23 @@ class TaskManager: ...@@ -291,23 +289,23 @@ class TaskManager:
return task_index return task_index
@property @property
def all_tasks(self) -> List[str]: def all_tasks(self) -> list[str]:
return self._all_tasks return self._all_tasks
@property @property
def all_groups(self) -> List[str]: def all_groups(self) -> list[str]:
return self._all_groups return self._all_groups
@property @property
def all_subtasks(self) -> List[str]: def all_subtasks(self) -> list[str]:
return self._all_subtasks return self._all_subtasks
@property @property
def all_tags(self) -> List[str]: def all_tags(self) -> list[str]:
return self._all_tags return self._all_tags
@property @property
def task_index(self) -> Dict[str, Dict[str, Union[str, int, List[str]]]]: def task_index(self) -> dict[str, dict[str, Union[str, int, list[str]]]]:
return self._task_index return self._task_index
def list_all_tasks( def list_all_tasks(
...@@ -340,15 +338,16 @@ class TaskManager: ...@@ -340,15 +338,16 @@ class TaskManager:
inc_list = inc_raw if isinstance(inc_raw, list) else [inc_raw] inc_list = inc_raw if isinstance(inc_raw, list) else [inc_raw]
for inc in inc_list: for inc in inc_list:
inc_path = Path(inc) if inc:
if not inc_path.is_absolute(): # treat as relative include inc_path = Path(inc)
inc_path = base.parent / inc_path if not inc_path.is_absolute(): # treat as relative include
try: inc_path = base.parent / inc_path
inc_cfg = load_yaml_config(inc_path, mode="simple") try:
except FileNotFoundError: inc_cfg = load_yaml_config(inc_path, mode="simple")
continue except FileNotFoundError:
if "output_type" in inc_cfg: continue
return inc_cfg["output_type"] if "output_type" in inc_cfg:
return inc_cfg["output_type"]
return "" return ""
# -------------------------------------------------------------- GROUP table # -------------------------------------------------------------- GROUP table
...@@ -452,7 +451,7 @@ class TaskManager: ...@@ -452,7 +451,7 @@ class TaskManager:
raise ValueError raise ValueError
return self.task_index[name]["yaml_path"] return self.task_index[name]["yaml_path"]
def _get_config(self, name: str) -> Dict: def _get_config(self, name: str) -> dict:
if name not in self.task_index: if name not in self.task_index:
raise ValueError raise ValueError
yaml_path = self._get_yaml_path(name) yaml_path = self._get_yaml_path(name)
...@@ -461,7 +460,7 @@ class TaskManager: ...@@ -461,7 +460,7 @@ class TaskManager:
else: else:
return load_yaml_config(Path(yaml_path), mode="full") return load_yaml_config(Path(yaml_path), mode="full")
def _get_tasklist(self, name: str) -> Union[List[str], int]: def _get_tasklist(self, name: str) -> Union[list[str], int]:
if self._name_is_task(name): if self._name_is_task(name):
raise ValueError raise ValueError
return self.task_index[name]["task"] return self.task_index[name]["task"]
...@@ -471,8 +470,8 @@ class TaskManager: ...@@ -471,8 +470,8 @@ class TaskManager:
task_name: str, task_name: str,
task_type: str, task_type: str,
yaml_path: str, yaml_path: str,
tasks_and_groups: Dict[str, Dict], tasks_and_groups: dict[str, dict],
config: Optional[Dict] = None, config: Optional[dict] = None,
populate_tags_fn: Optional[callable] = None, populate_tags_fn: Optional[callable] = None,
) -> None: ) -> None:
"""Helper method to register a task in the tasks_and_groups dict""" """Helper method to register a task in the tasks_and_groups dict"""
...@@ -485,8 +484,8 @@ class TaskManager: ...@@ -485,8 +484,8 @@ class TaskManager:
populate_tags_fn(config, task_name, tasks_and_groups) populate_tags_fn(config, task_name, tasks_and_groups)
def _merge_task_configs( def _merge_task_configs(
self, base_config: Dict, task_specific_config: Dict, task_name: str self, base_config: dict, task_specific_config: dict, task_name: str
) -> Dict: ) -> dict:
"""Merge base config with task-specific overrides for task_list configs""" """Merge base config with task-specific overrides for task_list configs"""
if task_specific_config: if task_specific_config:
task_specific_config = task_specific_config.copy() task_specific_config = task_specific_config.copy()
...@@ -495,8 +494,8 @@ class TaskManager: ...@@ -495,8 +494,8 @@ class TaskManager:
return {**base_config, "task": task_name} return {**base_config, "task": task_name}
def _process_tag_subtasks( def _process_tag_subtasks(
self, tag_name: str, update_config: Optional[Dict] = None self, tag_name: str, update_config: Optional[dict] = None
) -> Dict: ) -> dict:
"""Process subtasks for a tag and return loaded tasks""" """Process subtasks for a tag and return loaded tasks"""
subtask_list = self._get_tasklist(tag_name) subtask_list = self._get_tasklist(tag_name)
fn = partial( fn = partial(
...@@ -505,7 +504,7 @@ class TaskManager: ...@@ -505,7 +504,7 @@ class TaskManager:
) )
return dict(collections.ChainMap(*map(fn, reversed(subtask_list)))) return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
def _process_alias(self, config: Dict, group: Optional[str] = None) -> Dict: def _process_alias(self, config: dict, group: Optional[str] = None) -> dict:
# If the group is not the same as the original # If the group is not the same as the original
# group which the group alias was intended for, # group which the group alias was intended for,
# Set the group_alias to None instead. # Set the group_alias to None instead.
...@@ -524,15 +523,15 @@ class TaskManager: ...@@ -524,15 +523,15 @@ class TaskManager:
def _load_individual_task_or_group( def _load_individual_task_or_group(
self, self,
name_or_config: Optional[Union[str, Dict]] = None, name_or_config: Optional[Union[str, dict]] = None,
parent_name: Optional[str] = None, parent_name: Optional[str] = None,
update_config: Optional[Dict] = None, update_config: Optional[dict] = None,
) -> Mapping: ) -> Mapping:
from lm_eval.api.task import ConfigurableTask, Task from lm_eval.api.task import ConfigurableTask, Task
def _load_task( def _load_task(
config: Dict, task: str, yaml_path: Optional[str] = None config: dict, task: str, yaml_path: Optional[str] = None
) -> Dict[str, Union["ConfigurableTask", "Task"]]: ) -> dict[str, Union["ConfigurableTask", "Task"]]:
if "include" in config: if "include" in config:
# Store the task name to preserve it after include processing # Store the task name to preserve it after include processing
original_task_name = config.get("task", task) original_task_name = config.get("task", task)
...@@ -568,8 +567,8 @@ class TaskManager: ...@@ -568,8 +567,8 @@ class TaskManager:
return {task: task_object} return {task: task_object}
def _get_group_and_subtask_from_config( def _get_group_and_subtask_from_config(
config: Dict, config: dict,
) -> tuple[ConfigurableGroup, List[str]]: ) -> tuple[ConfigurableGroup, list[str]]:
if self.metadata is not None: if self.metadata is not None:
config["metadata"] = config.get("metadata", {}) | self.metadata config["metadata"] = config.get("metadata", {}) | self.metadata
group_name = ConfigurableGroup(config=config) group_name = ConfigurableGroup(config=config)
...@@ -582,8 +581,8 @@ class TaskManager: ...@@ -582,8 +581,8 @@ class TaskManager:
return group_name, subtask_list return group_name, subtask_list
def _process_group_config( def _process_group_config(
config: Dict, update_config: Optional[Dict] = None config: dict, update_config: Optional[dict] = None
) -> tuple[Dict, Optional[Dict]]: ) -> tuple[dict, Optional[dict]]:
if update_config is not None: if update_config is not None:
config = {**config, **update_config} config = {**config, **update_config}
_update_config = { _update_config = {
...@@ -716,15 +715,15 @@ class TaskManager: ...@@ -716,15 +715,15 @@ class TaskManager:
} }
def load_task_or_group( def load_task_or_group(
self, task_list: Optional[Union[str, List[str]]] = None self, task_list: Optional[Union[str, list[str]]] = None
) -> Dict: ) -> dict:
"""Loads a dictionary of task objects from a list """Loads a dictionary of task objects from a list
:param task_list: Union[str, list] = None :param task_list: Union[str, list] = None
Single string or list of string of task names to be loaded Single string or list of string of task names to be loaded
:return :return
Dictionary of task objects dictionary of task objects
""" """
if isinstance(task_list, str): if isinstance(task_list, str):
task_list = [task_list] task_list = [task_list]
...@@ -739,10 +738,10 @@ class TaskManager: ...@@ -739,10 +738,10 @@ class TaskManager:
) )
return all_loaded_tasks return all_loaded_tasks
def load_config(self, config: Dict) -> Mapping: def load_config(self, config: dict) -> Mapping:
return self._load_individual_task_or_group(config) return self._load_individual_task_or_group(config)
def _get_task_and_group(self, task_dir: Union[str, Path]) -> Dict[str, Dict]: def _get_task_and_group(self, task_dir: Union[str, Path]) -> dict[str, dict]:
"""Creates a dictionary of tasks index with the following metadata, """Creates a dictionary of tasks index with the following metadata,
- `type`, that can be either `task`, `python_task`, `group` or `tags`. - `type`, that can be either `task`, `python_task`, `group` or `tags`.
`task` refer to regular task configs, `python_task` are special `task` refer to regular task configs, `python_task` are special
...@@ -762,11 +761,11 @@ class TaskManager: ...@@ -762,11 +761,11 @@ class TaskManager:
A directory to check for tasks A directory to check for tasks
:return :return
Dictionary of task names as key and task metadata dictionary of task names as key and task metadata
""" """
def _populate_tags_and_groups( def _populate_tags_and_groups(
config: Dict, task: str, tasks_and_groups: Dict[str, Dict] config: dict, task: str, tasks_and_groups: dict[str, dict]
) -> None: ) -> None:
# TODO: remove group in next release # TODO: remove group in next release
if "tag" in config: if "tag" in config:
...@@ -868,7 +867,7 @@ class TaskManager: ...@@ -868,7 +867,7 @@ class TaskManager:
return tasks_and_groups return tasks_and_groups
def get_task_name_from_config(task_config: Dict[str, str]) -> str: def get_task_name_from_config(task_config: dict[str, str]) -> str:
if "task" in task_config: if "task" in task_config:
return task_config["task"] return task_config["task"]
if "dataset_name" in task_config: if "dataset_name" in task_config:
...@@ -890,7 +889,7 @@ def get_task_name_from_object(task_object: Union["ConfigurableTask", "Task"]) -> ...@@ -890,7 +889,7 @@ def get_task_name_from_object(task_object: Union["ConfigurableTask", "Task"]) ->
) )
def _check_duplicates(task_dict: Dict[str, List[str]]) -> None: def _check_duplicates(task_dict: dict[str, list[str]]) -> None:
"""helper function solely used in validating get_task_dict output. """helper function solely used in validating get_task_dict output.
Takes the output of lm_eval.evaluator_utils.get_subtask_list and Takes the output of lm_eval.evaluator_utils.get_subtask_list and
returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are
...@@ -918,12 +917,12 @@ def _check_duplicates(task_dict: Dict[str, List[str]]) -> None: ...@@ -918,12 +917,12 @@ def _check_duplicates(task_dict: Dict[str, List[str]]) -> None:
def get_task_dict( def get_task_dict(
task_name_list: Union[str, List[Union[str, Dict, "Task"]]], task_name_list: Union[str, list[Union[str, dict, "Task"]]],
task_manager: Optional[TaskManager] = None, task_manager: Optional[TaskManager] = None,
) -> Dict[str, Union["ConfigurableTask", "Task"]]: ) -> dict[str, Union["ConfigurableTask", "Task"]]:
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object. """Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
:param task_name_list: List[Union[str, Dict, Task]] :param task_name_list: list[Union[str, dict, Task]]
Name of model or LM object, see lm_eval.models.get_model Name of model or LM object, see lm_eval.models.get_model
:param task_manager: TaskManager = None :param task_manager: TaskManager = None
A TaskManager object that stores indexed tasks. If not set, A TaskManager object that stores indexed tasks. If not set,
...@@ -932,7 +931,7 @@ def get_task_dict( ...@@ -932,7 +931,7 @@ def get_task_dict(
via `include_path` via `include_path`
:return :return
Dictionary of task objects dictionary of task objects
""" """
from lm_eval.api.task import ConfigurableTask, Task from lm_eval.api.task import ConfigurableTask, Task
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment