Commit 85f61d85 authored by Baber's avatar Baber
Browse files

refactor: add type hints

parent 5454e95d
......@@ -88,31 +88,34 @@ class TaskManager:
return task_index
@property
def all_tasks(self):
def all_tasks(self) -> List[str]:
return self._all_tasks
@property
def all_groups(self):
def all_groups(self) -> List[str]:
return self._all_groups
@property
def all_subtasks(self):
def all_subtasks(self) -> List[str]:
return self._all_subtasks
@property
def all_tags(self):
def all_tags(self) -> List[str]:
return self._all_tags
@property
def task_index(self):
def task_index(self) -> Dict[str, Dict[str, Union[str, int, List[str]]]]:
return self._task_index
def list_all_tasks(
self, list_groups=True, list_tags=True, list_subtasks=True
self,
list_groups: bool = True,
list_tags: bool = True,
list_subtasks: bool = True,
) -> str:
from pytablewriter import MarkdownTableWriter
def sanitize_path(path):
def sanitize_path(path: str) -> str:
# don't print full path if we are within the lm_eval/tasks dir !
# if we aren't though, provide the full path.
if "lm_eval/tasks/" in path:
......@@ -210,12 +213,12 @@ class TaskManager:
def _config_is_task_list(self, config: dict) -> bool:
return "task_list" in config and isinstance(config["task_list"], list)
def _get_yaml_path(self, name: str):
def _get_yaml_path(self, name: str) -> Union[str, int]:
if name not in self.task_index:
raise ValueError
return self.task_index[name]["yaml_path"]
def _get_config(self, name):
def _get_config(self, name: str) -> Dict:
if name not in self.task_index:
raise ValueError
yaml_path = self._get_yaml_path(name)
......@@ -224,7 +227,7 @@ class TaskManager:
else:
return utils.load_yaml_config(yaml_path, mode="full")
def _get_tasklist(self, name):
def _get_tasklist(self, name: str) -> Union[List[str], int]:
if self._name_is_task(name):
raise ValueError
return self.task_index[name]["task"]
......@@ -234,10 +237,10 @@ class TaskManager:
task_name: str,
task_type: str,
yaml_path: str,
tasks_and_groups: dict,
config: dict = None,
populate_tags_fn=None,
):
tasks_and_groups: Dict[str, Dict],
config: Optional[Dict] = None,
populate_tags_fn: Optional[callable] = None,
) -> None:
"""Helper method to register a task in the tasks_and_groups dict"""
tasks_and_groups[task_name] = {
"type": task_type,
......@@ -248,8 +251,8 @@ class TaskManager:
populate_tags_fn(config, task_name, tasks_and_groups)
def _merge_task_configs(
self, base_config: dict, task_specific_config: dict, task_name: str
) -> dict:
self, base_config: Dict, task_specific_config: Dict, task_name: str
) -> Dict:
"""Merge base config with task-specific overrides for task_list configs"""
if task_specific_config:
task_specific_config = task_specific_config.copy()
......@@ -257,7 +260,9 @@ class TaskManager:
return {**base_config, **task_specific_config, "task": task_name}
return {**base_config, "task": task_name}
def _process_tag_subtasks(self, tag_name: str, update_config: dict = None):
def _process_tag_subtasks(
self, tag_name: str, update_config: Optional[Dict] = None
) -> Dict:
"""Process subtasks for a tag and return loaded tasks"""
subtask_list = self._get_tasklist(tag_name)
fn = partial(
......@@ -266,7 +271,7 @@ class TaskManager:
)
return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
def _process_alias(self, config, group=None):
def _process_alias(self, config: Dict, group: Optional[str] = None) -> Dict:
# If the group is not the same as the original
# group which the group alias was intended for,
# Set the group_alias to None instead.
......@@ -275,7 +280,7 @@ class TaskManager:
config["group_alias"] = None
return config
def _class_has_config_in_constructor(self, cls):
def _class_has_config_in_constructor(self, cls) -> bool:
constructor = getattr(cls, "__init__", None)
return (
"config" in inspect.signature(constructor).parameters
......@@ -285,11 +290,13 @@ class TaskManager:
def _load_individual_task_or_group(
self,
name_or_config: Optional[Union[str, dict]] = None,
name_or_config: Optional[Union[str, Dict]] = None,
parent_name: Optional[str] = None,
update_config: Optional[dict] = None,
update_config: Optional[Dict] = None,
) -> Mapping:
def _load_task(config, task, yaml_path=None):
def _load_task(
config: Dict, task: str, yaml_path: Optional[str] = None
) -> Dict[str, Union[ConfigurableTask, Task]]:
if "include" in config:
# Store the task name to preserve it after include processing
original_task_name = config.get("task", task)
......@@ -325,8 +332,8 @@ class TaskManager:
return {task: task_object}
def _get_group_and_subtask_from_config(
config: dict,
) -> tuple[ConfigurableGroup, list[str]]:
config: Dict,
) -> tuple[ConfigurableGroup, List[str]]:
if self.metadata is not None:
config["metadata"] = config.get("metadata", {}) | self.metadata
group_name = ConfigurableGroup(config=config)
......@@ -339,8 +346,8 @@ class TaskManager:
return group_name, subtask_list
def _process_group_config(
config: dict, update_config: dict = None
) -> tuple[dict, dict]:
config: Dict, update_config: Optional[Dict] = None
) -> tuple[Dict, Optional[Dict]]:
if update_config is not None:
config = {**config, **update_config}
_update_config = {
......@@ -472,7 +479,9 @@ class TaskManager:
group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
}
def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict:
def load_task_or_group(
self, task_list: Optional[Union[str, List[str]]] = None
) -> Dict:
"""Loads a dictionary of task objects from a list
:param task_list: Union[str, list] = None
......@@ -494,10 +503,10 @@ class TaskManager:
)
return all_loaded_tasks
def load_config(self, config: Dict):
def load_config(self, config: Dict) -> Mapping:
return self._load_individual_task_or_group(config)
def _get_task_and_group(self, task_dir: Union[str, Path]):
def _get_task_and_group(self, task_dir: Union[str, Path]) -> Dict[str, Dict]:
"""Creates a dictionary of tasks index with the following metadata,
- `type`, that can be either `task`, `python_task`, `group` or `tags`.
`task` refer to regular task configs, `python_task` are special
......@@ -520,7 +529,9 @@ class TaskManager:
Dictionary of task names as key and task metadata
"""
def _populate_tags_and_groups(config, task, tasks_and_groups):
def _populate_tags_and_groups(
config: Dict, task: str, tasks_and_groups: Dict[str, Dict]
) -> None:
# TODO: remove group in next release
if "tag" in config:
attr_list = config["tag"]
......@@ -557,7 +568,7 @@ class TaskManager:
for f in file_list:
if f.endswith(".yaml"):
yaml_path = root_path / f
config = utils.load_yaml_config(str(yaml_path), mode="simple")
config = utils.load_yaml_config(yaml_path, mode="simple")
if self._config_is_python_task(config):
# This is a python class config
task = config["task"]
......@@ -629,7 +640,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str:
return "{dataset_path}".format(**task_config)
def get_task_name_from_object(task_object):
def get_task_name_from_object(task_object: Union[ConfigurableTask, Task]) -> str:
if hasattr(task_object, "config"):
return task_object._config["task"]
......@@ -642,7 +653,7 @@ def get_task_name_from_object(task_object):
)
def _check_duplicates(task_dict: dict) -> None:
def _check_duplicates(task_dict: Dict[str, List[str]]) -> None:
"""helper function solely used in validating get_task_dict output.
Takes the output of lm_eval.evaluator_utils.get_subtask_list and
returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are
......@@ -672,7 +683,7 @@ def _check_duplicates(task_dict: dict) -> None:
def get_task_dict(
task_name_list: Union[str, List[Union[str, Dict, Task]]],
task_manager: Optional[TaskManager] = None,
):
) -> Dict[str, Union[ConfigurableTask, Task]]:
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
:param task_name_list: List[Union[str, Dict, Task]]
......
......@@ -11,7 +11,7 @@ import re
from dataclasses import asdict, is_dataclass
from itertools import islice
from pathlib import Path
from typing import Any, Callable, Generator, List, Optional, Tuple
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
import numpy as np
import yaml
......@@ -441,11 +441,11 @@ def positional_deprecated(fn):
return _wrapper
def ignore_constructor(loader, node):
def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> yaml.Node:
return node
def import_function(loader: yaml.Loader, node, yaml_path: Path):
def import_function(loader: yaml.Loader, node: yaml.Node, yaml_path: Path) -> Callable:
function_name = loader.construct_scalar(node)
*module_name, function_name = function_name.split(".")
......@@ -468,8 +468,11 @@ def import_function(loader: yaml.Loader, node, yaml_path: Path):
def load_yaml_config(
yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"
) -> dict:
yaml_path: Optional[Union[str, Path]] = None,
yaml_config: Optional[Dict] = None,
yaml_dir: Optional[Union[str, Path]] = None,
mode: str = "full",
) -> Dict:
# Convert yaml_path to Path object if it's a string
if yaml_path is not None:
yaml_path = Path(yaml_path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment