Commit 85f61d85 authored by Baber
Browse files

refactor: add type hints

parent 5454e95d
...@@ -88,31 +88,34 @@ class TaskManager: ...@@ -88,31 +88,34 @@ class TaskManager:
return task_index return task_index
@property @property
def all_tasks(self): def all_tasks(self) -> List[str]:
return self._all_tasks return self._all_tasks
@property @property
def all_groups(self): def all_groups(self) -> List[str]:
return self._all_groups return self._all_groups
@property @property
def all_subtasks(self): def all_subtasks(self) -> List[str]:
return self._all_subtasks return self._all_subtasks
@property @property
def all_tags(self): def all_tags(self) -> List[str]:
return self._all_tags return self._all_tags
@property @property
def task_index(self): def task_index(self) -> Dict[str, Dict[str, Union[str, int, List[str]]]]:
return self._task_index return self._task_index
def list_all_tasks( def list_all_tasks(
self, list_groups=True, list_tags=True, list_subtasks=True self,
list_groups: bool = True,
list_tags: bool = True,
list_subtasks: bool = True,
) -> str: ) -> str:
from pytablewriter import MarkdownTableWriter from pytablewriter import MarkdownTableWriter
def sanitize_path(path): def sanitize_path(path: str) -> str:
# don't print full path if we are within the lm_eval/tasks dir ! # don't print full path if we are within the lm_eval/tasks dir !
# if we aren't though, provide the full path. # if we aren't though, provide the full path.
if "lm_eval/tasks/" in path: if "lm_eval/tasks/" in path:
...@@ -210,12 +213,12 @@ class TaskManager: ...@@ -210,12 +213,12 @@ class TaskManager:
def _config_is_task_list(self, config: dict) -> bool: def _config_is_task_list(self, config: dict) -> bool:
return "task_list" in config and isinstance(config["task_list"], list) return "task_list" in config and isinstance(config["task_list"], list)
def _get_yaml_path(self, name: str): def _get_yaml_path(self, name: str) -> Union[str, int]:
if name not in self.task_index: if name not in self.task_index:
raise ValueError raise ValueError
return self.task_index[name]["yaml_path"] return self.task_index[name]["yaml_path"]
def _get_config(self, name): def _get_config(self, name: str) -> Dict:
if name not in self.task_index: if name not in self.task_index:
raise ValueError raise ValueError
yaml_path = self._get_yaml_path(name) yaml_path = self._get_yaml_path(name)
...@@ -224,7 +227,7 @@ class TaskManager: ...@@ -224,7 +227,7 @@ class TaskManager:
else: else:
return utils.load_yaml_config(yaml_path, mode="full") return utils.load_yaml_config(yaml_path, mode="full")
def _get_tasklist(self, name): def _get_tasklist(self, name: str) -> Union[List[str], int]:
if self._name_is_task(name): if self._name_is_task(name):
raise ValueError raise ValueError
return self.task_index[name]["task"] return self.task_index[name]["task"]
...@@ -234,10 +237,10 @@ class TaskManager: ...@@ -234,10 +237,10 @@ class TaskManager:
task_name: str, task_name: str,
task_type: str, task_type: str,
yaml_path: str, yaml_path: str,
tasks_and_groups: dict, tasks_and_groups: Dict[str, Dict],
config: dict = None, config: Optional[Dict] = None,
populate_tags_fn=None, populate_tags_fn: Optional[callable] = None,
): ) -> None:
"""Helper method to register a task in the tasks_and_groups dict""" """Helper method to register a task in the tasks_and_groups dict"""
tasks_and_groups[task_name] = { tasks_and_groups[task_name] = {
"type": task_type, "type": task_type,
...@@ -248,8 +251,8 @@ class TaskManager: ...@@ -248,8 +251,8 @@ class TaskManager:
populate_tags_fn(config, task_name, tasks_and_groups) populate_tags_fn(config, task_name, tasks_and_groups)
def _merge_task_configs( def _merge_task_configs(
self, base_config: dict, task_specific_config: dict, task_name: str self, base_config: Dict, task_specific_config: Dict, task_name: str
) -> dict: ) -> Dict:
"""Merge base config with task-specific overrides for task_list configs""" """Merge base config with task-specific overrides for task_list configs"""
if task_specific_config: if task_specific_config:
task_specific_config = task_specific_config.copy() task_specific_config = task_specific_config.copy()
...@@ -257,7 +260,9 @@ class TaskManager: ...@@ -257,7 +260,9 @@ class TaskManager:
return {**base_config, **task_specific_config, "task": task_name} return {**base_config, **task_specific_config, "task": task_name}
return {**base_config, "task": task_name} return {**base_config, "task": task_name}
def _process_tag_subtasks(self, tag_name: str, update_config: dict = None): def _process_tag_subtasks(
self, tag_name: str, update_config: Optional[Dict] = None
) -> Dict:
"""Process subtasks for a tag and return loaded tasks""" """Process subtasks for a tag and return loaded tasks"""
subtask_list = self._get_tasklist(tag_name) subtask_list = self._get_tasklist(tag_name)
fn = partial( fn = partial(
...@@ -266,7 +271,7 @@ class TaskManager: ...@@ -266,7 +271,7 @@ class TaskManager:
) )
return dict(collections.ChainMap(*map(fn, reversed(subtask_list)))) return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
def _process_alias(self, config, group=None): def _process_alias(self, config: Dict, group: Optional[str] = None) -> Dict:
# If the group is not the same as the original # If the group is not the same as the original
# group which the group alias was intended for, # group which the group alias was intended for,
# Set the group_alias to None instead. # Set the group_alias to None instead.
...@@ -275,7 +280,7 @@ class TaskManager: ...@@ -275,7 +280,7 @@ class TaskManager:
config["group_alias"] = None config["group_alias"] = None
return config return config
def _class_has_config_in_constructor(self, cls): def _class_has_config_in_constructor(self, cls) -> bool:
constructor = getattr(cls, "__init__", None) constructor = getattr(cls, "__init__", None)
return ( return (
"config" in inspect.signature(constructor).parameters "config" in inspect.signature(constructor).parameters
...@@ -285,11 +290,13 @@ class TaskManager: ...@@ -285,11 +290,13 @@ class TaskManager:
def _load_individual_task_or_group( def _load_individual_task_or_group(
self, self,
name_or_config: Optional[Union[str, dict]] = None, name_or_config: Optional[Union[str, Dict]] = None,
parent_name: Optional[str] = None, parent_name: Optional[str] = None,
update_config: Optional[dict] = None, update_config: Optional[Dict] = None,
) -> Mapping: ) -> Mapping:
def _load_task(config, task, yaml_path=None): def _load_task(
config: Dict, task: str, yaml_path: Optional[str] = None
) -> Dict[str, Union[ConfigurableTask, Task]]:
if "include" in config: if "include" in config:
# Store the task name to preserve it after include processing # Store the task name to preserve it after include processing
original_task_name = config.get("task", task) original_task_name = config.get("task", task)
...@@ -325,8 +332,8 @@ class TaskManager: ...@@ -325,8 +332,8 @@ class TaskManager:
return {task: task_object} return {task: task_object}
def _get_group_and_subtask_from_config( def _get_group_and_subtask_from_config(
config: dict, config: Dict,
) -> tuple[ConfigurableGroup, list[str]]: ) -> tuple[ConfigurableGroup, List[str]]:
if self.metadata is not None: if self.metadata is not None:
config["metadata"] = config.get("metadata", {}) | self.metadata config["metadata"] = config.get("metadata", {}) | self.metadata
group_name = ConfigurableGroup(config=config) group_name = ConfigurableGroup(config=config)
...@@ -339,8 +346,8 @@ class TaskManager: ...@@ -339,8 +346,8 @@ class TaskManager:
return group_name, subtask_list return group_name, subtask_list
def _process_group_config( def _process_group_config(
config: dict, update_config: dict = None config: Dict, update_config: Optional[Dict] = None
) -> tuple[dict, dict]: ) -> tuple[Dict, Optional[Dict]]:
if update_config is not None: if update_config is not None:
config = {**config, **update_config} config = {**config, **update_config}
_update_config = { _update_config = {
...@@ -472,7 +479,9 @@ class TaskManager: ...@@ -472,7 +479,9 @@ class TaskManager:
group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list)))) group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
} }
def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict: def load_task_or_group(
self, task_list: Optional[Union[str, List[str]]] = None
) -> Dict:
"""Loads a dictionary of task objects from a list """Loads a dictionary of task objects from a list
:param task_list: Union[str, list] = None :param task_list: Union[str, list] = None
...@@ -494,10 +503,10 @@ class TaskManager: ...@@ -494,10 +503,10 @@ class TaskManager:
) )
return all_loaded_tasks return all_loaded_tasks
def load_config(self, config: Dict): def load_config(self, config: Dict) -> Mapping:
return self._load_individual_task_or_group(config) return self._load_individual_task_or_group(config)
def _get_task_and_group(self, task_dir: Union[str, Path]): def _get_task_and_group(self, task_dir: Union[str, Path]) -> Dict[str, Dict]:
"""Creates a dictionary of tasks index with the following metadata, """Creates a dictionary of tasks index with the following metadata,
- `type`, that can be either `task`, `python_task`, `group` or `tags`. - `type`, that can be either `task`, `python_task`, `group` or `tags`.
`task` refer to regular task configs, `python_task` are special `task` refer to regular task configs, `python_task` are special
...@@ -520,7 +529,9 @@ class TaskManager: ...@@ -520,7 +529,9 @@ class TaskManager:
Dictionary of task names as key and task metadata Dictionary of task names as key and task metadata
""" """
def _populate_tags_and_groups(config, task, tasks_and_groups): def _populate_tags_and_groups(
config: Dict, task: str, tasks_and_groups: Dict[str, Dict]
) -> None:
# TODO: remove group in next release # TODO: remove group in next release
if "tag" in config: if "tag" in config:
attr_list = config["tag"] attr_list = config["tag"]
...@@ -557,7 +568,7 @@ class TaskManager: ...@@ -557,7 +568,7 @@ class TaskManager:
for f in file_list: for f in file_list:
if f.endswith(".yaml"): if f.endswith(".yaml"):
yaml_path = root_path / f yaml_path = root_path / f
config = utils.load_yaml_config(str(yaml_path), mode="simple") config = utils.load_yaml_config(yaml_path, mode="simple")
if self._config_is_python_task(config): if self._config_is_python_task(config):
# This is a python class config # This is a python class config
task = config["task"] task = config["task"]
...@@ -629,7 +640,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str: ...@@ -629,7 +640,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str:
return "{dataset_path}".format(**task_config) return "{dataset_path}".format(**task_config)
def get_task_name_from_object(task_object): def get_task_name_from_object(task_object: Union[ConfigurableTask, Task]) -> str:
if hasattr(task_object, "config"): if hasattr(task_object, "config"):
return task_object._config["task"] return task_object._config["task"]
...@@ -642,7 +653,7 @@ def get_task_name_from_object(task_object): ...@@ -642,7 +653,7 @@ def get_task_name_from_object(task_object):
) )
def _check_duplicates(task_dict: dict) -> None: def _check_duplicates(task_dict: Dict[str, List[str]]) -> None:
"""helper function solely used in validating get_task_dict output. """helper function solely used in validating get_task_dict output.
Takes the output of lm_eval.evaluator_utils.get_subtask_list and Takes the output of lm_eval.evaluator_utils.get_subtask_list and
returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are
...@@ -672,7 +683,7 @@ def _check_duplicates(task_dict: dict) -> None: ...@@ -672,7 +683,7 @@ def _check_duplicates(task_dict: dict) -> None:
def get_task_dict( def get_task_dict(
task_name_list: Union[str, List[Union[str, Dict, Task]]], task_name_list: Union[str, List[Union[str, Dict, Task]]],
task_manager: Optional[TaskManager] = None, task_manager: Optional[TaskManager] = None,
): ) -> Dict[str, Union[ConfigurableTask, Task]]:
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object. """Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
:param task_name_list: List[Union[str, Dict, Task]] :param task_name_list: List[Union[str, Dict, Task]]
......
...@@ -11,7 +11,7 @@ import re ...@@ -11,7 +11,7 @@ import re
from dataclasses import asdict, is_dataclass from dataclasses import asdict, is_dataclass
from itertools import islice from itertools import islice
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Generator, List, Optional, Tuple from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
import numpy as np import numpy as np
import yaml import yaml
...@@ -441,11 +441,11 @@ def positional_deprecated(fn): ...@@ -441,11 +441,11 @@ def positional_deprecated(fn):
return _wrapper return _wrapper
def ignore_constructor(loader, node): def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> yaml.Node:
return node return node
def import_function(loader: yaml.Loader, node, yaml_path: Path): def import_function(loader: yaml.Loader, node: yaml.Node, yaml_path: Path) -> Callable:
function_name = loader.construct_scalar(node) function_name = loader.construct_scalar(node)
*module_name, function_name = function_name.split(".") *module_name, function_name = function_name.split(".")
...@@ -468,8 +468,11 @@ def import_function(loader: yaml.Loader, node, yaml_path: Path): ...@@ -468,8 +468,11 @@ def import_function(loader: yaml.Loader, node, yaml_path: Path):
def load_yaml_config( def load_yaml_config(
yaml_path=None, yaml_config=None, yaml_dir=None, mode="full" yaml_path: Optional[Union[str, Path]] = None,
) -> dict: yaml_config: Optional[Dict] = None,
yaml_dir: Optional[Union[str, Path]] = None,
mode: str = "full",
) -> Dict:
# Convert yaml_path to Path object if it's a string # Convert yaml_path to Path object if it's a string
if yaml_path is not None: if yaml_path is not None:
yaml_path = Path(yaml_path) yaml_path = Path(yaml_path)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment