Commit 3c969207 authored by Baber
Browse files

nit

parent 6fc2ac49
......@@ -5,14 +5,12 @@ import pathlib
import sys
from typing import List, Optional, Tuple, Union
from lm_eval.api.group import ConfigurableGroup
from lm_eval.api.metrics import (
aggregate_subtask_metrics,
mean,
pooled_sample_stderr,
stderr_for_metric,
)
from lm_eval.api.task import Task
from lm_eval.utils import positional_deprecated
......@@ -153,6 +151,9 @@ def get_task_list(task_dict: dict) -> List[TaskOutput]:
def get_subtask_list(task_dict, task_root=None, depth=0):
from lm_eval.api.group import ConfigurableGroup
from lm_eval.api.task import Task
subtask_list = {}
for group_obj, task_obj in task_dict.items():
if isinstance(group_obj, ConfigurableGroup):
......@@ -224,6 +225,8 @@ def prepare_print_tasks(
task_depth=0,
group_depth=0,
) -> Tuple[dict, dict]:
from lm_eval.api.task import Task
"""
@param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
value is a list of task names.
......@@ -238,6 +241,7 @@ def prepare_print_tasks(
Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
"""
from lm_eval.api.group import ConfigurableGroup
def _sort_task_dict(task_dict):
"""
......@@ -395,6 +399,9 @@ def consolidate_group_results(
The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple.
In the top-level invocation of this function, task_aggregation_list is ignored.
"""
from lm_eval.api.group import ConfigurableGroup
from lm_eval.api.task import Task
if task_root is None:
task_root = {}
......
......@@ -7,15 +7,29 @@ import sys
from functools import partial
from glob import iglob
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
List,
Mapping,
Optional,
Union,
)
import yaml
from memory_profiler import profile
from yaml import YAMLError
from lm_eval import utils
from lm_eval.api.group import ConfigurableGroup, GroupConfig
from lm_eval.api.task import ConfigurableTask, Task
from lm_eval.evaluator_utils import get_subtask_list
from lm_eval.utils import pattern_match, setup_logging
if TYPE_CHECKING:
from lm_eval.api.task import ConfigurableTask, Task
# Key names of the dict produced by a default-constructed GroupConfig
# (computed once, when this module is imported).
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
......@@ -190,6 +204,7 @@ class TaskManager:
"""
@profile
def __init__(
self,
verbosity: Optional[str] = None,
......@@ -198,7 +213,7 @@ class TaskManager:
metadata: Optional[dict] = None,
) -> None:
if verbosity is not None:
utils.setup_logging(verbosity)
setup_logging(verbosity)
self.include_path = include_path
self.metadata = metadata
self._task_index = self.initialize_tasks(
......@@ -222,6 +237,7 @@ class TaskManager:
self.task_group_map = collections.defaultdict(list)
@profile
def initialize_tasks(
self,
include_path: Optional[Union[str, Path, List[Union[str, Path]]]] = None,
......@@ -375,7 +391,7 @@ class TaskManager:
return "".join(parts)
def match_tasks(self, task_list: list[str]) -> list[str]:
    """Match the requested names/patterns against ``self.all_tasks``.

    Delegates to ``pattern_match`` (imported from ``lm_eval.utils``).
    NOTE(review): the exact matching rules (e.g. wildcard handling, ordering
    of the result) are defined by ``pattern_match`` — confirm there.

    :param task_list: task names or patterns to look up.
    :return: the value returned by ``pattern_match(task_list, self.all_tasks)``.
    """
    # Call the helper through its direct import from lm_eval.utils.
    return pattern_match(task_list, self.all_tasks)
def _name_is_registered(self, name: str) -> bool:
    """Return whether ``name`` is contained in ``self.all_tasks``.

    :param name: name to look up.
    :return: result of the membership test ``name in self.all_tasks``.
    """
    return name in self.all_tasks
......@@ -492,9 +508,11 @@ class TaskManager:
parent_name: Optional[str] = None,
update_config: Optional[Dict] = None,
) -> Mapping:
from lm_eval.api.task import ConfigurableTask, Task
def _load_task(
config: Dict, task: str, yaml_path: Optional[str] = None
) -> Dict[str, Union[ConfigurableTask, Task]]:
) -> Dict[str, Union["ConfigurableTask", "Task"]]:
if "include" in config:
# Store the task name to preserve it after include processing
original_task_name = config.get("task", task)
......@@ -704,6 +722,7 @@ class TaskManager:
def load_config(self, config: Dict) -> Mapping:
    """Load from an in-memory config dictionary.

    Thin wrapper: forwards ``config`` unchanged to
    ``self._load_individual_task_or_group`` and returns its result.

    :param config: configuration dictionary; its schema is defined by the
        callee. NOTE(review): confirm the expected keys there.
    :return: the mapping produced by ``_load_individual_task_or_group``.
    """
    return self._load_individual_task_or_group(config)
@profile
def _get_task_and_group(self, task_dir: Union[str, Path]) -> Dict[str, Dict]:
"""Creates a dictionary of tasks index with the following metadata,
- `type`, that can be either `task`, `python_task`, `group` or `tags`.
......@@ -839,7 +858,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str:
return "{dataset_path}".format(**task_config)
def get_task_name_from_object(task_object: Union[ConfigurableTask, Task]) -> str:
def get_task_name_from_object(task_object: Union["ConfigurableTask", "Task"]) -> str:
if hasattr(task_object, "config"):
return task_object._config["task"]
......@@ -879,10 +898,11 @@ def _check_duplicates(task_dict: Dict[str, List[str]]) -> None:
)
@profile
def get_task_dict(
task_name_list: Union[str, List[Union[str, Dict, Task]]],
task_name_list: Union[str, List[Union[str, Dict, "Task"]]],
task_manager: Optional[TaskManager] = None,
) -> Dict[str, Union[ConfigurableTask, Task]]:
) -> Dict[str, Union["ConfigurableTask", "Task"]]:
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
:param task_name_list: List[Union[str, Dict, Task]]
......@@ -896,6 +916,7 @@ def get_task_dict(
:return
Dictionary of task objects
"""
from lm_eval.api.task import ConfigurableTask, Task
task_name_from_string_dict = {}
task_name_from_config_dict = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment