"src/diffusers/pipelines/lumina2/pipeline_lumina2.py" did not exist on "a2d424eb2ed2be3f1d77ad9a5a1f309825c6c863"
Commit 3c969207 authored by Baber's avatar Baber
Browse files

nit

parent 6fc2ac49
...@@ -5,14 +5,12 @@ import pathlib ...@@ -5,14 +5,12 @@ import pathlib
import sys import sys
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from lm_eval.api.group import ConfigurableGroup
from lm_eval.api.metrics import ( from lm_eval.api.metrics import (
aggregate_subtask_metrics, aggregate_subtask_metrics,
mean, mean,
pooled_sample_stderr, pooled_sample_stderr,
stderr_for_metric, stderr_for_metric,
) )
from lm_eval.api.task import Task
from lm_eval.utils import positional_deprecated from lm_eval.utils import positional_deprecated
...@@ -153,6 +151,9 @@ def get_task_list(task_dict: dict) -> List[TaskOutput]: ...@@ -153,6 +151,9 @@ def get_task_list(task_dict: dict) -> List[TaskOutput]:
def get_subtask_list(task_dict, task_root=None, depth=0): def get_subtask_list(task_dict, task_root=None, depth=0):
from lm_eval.api.group import ConfigurableGroup
from lm_eval.api.task import Task
subtask_list = {} subtask_list = {}
for group_obj, task_obj in task_dict.items(): for group_obj, task_obj in task_dict.items():
if isinstance(group_obj, ConfigurableGroup): if isinstance(group_obj, ConfigurableGroup):
...@@ -224,6 +225,8 @@ def prepare_print_tasks( ...@@ -224,6 +225,8 @@ def prepare_print_tasks(
task_depth=0, task_depth=0,
group_depth=0, group_depth=0,
) -> Tuple[dict, dict]: ) -> Tuple[dict, dict]:
from lm_eval.api.task import Task
""" """
@param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its @param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
value is a list of task names. value is a list of task names.
...@@ -238,6 +241,7 @@ def prepare_print_tasks( ...@@ -238,6 +241,7 @@ def prepare_print_tasks(
Prepares the task hierarchy and aggregates the results for each task and group recursively for printing. Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
""" """
from lm_eval.api.group import ConfigurableGroup
def _sort_task_dict(task_dict): def _sort_task_dict(task_dict):
""" """
...@@ -395,6 +399,9 @@ def consolidate_group_results( ...@@ -395,6 +399,9 @@ def consolidate_group_results(
The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple. The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple.
In the top-level invocation of this function, task_aggregation_list is ignored. In the top-level invocation of this function, task_aggregation_list is ignored.
""" """
from lm_eval.api.group import ConfigurableGroup
from lm_eval.api.task import Task
if task_root is None: if task_root is None:
task_root = {} task_root = {}
......
...@@ -7,15 +7,29 @@ import sys ...@@ -7,15 +7,29 @@ import sys
from functools import partial from functools import partial
from glob import iglob from glob import iglob
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Union from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
List,
Mapping,
Optional,
Union,
)
import yaml import yaml
from memory_profiler import profile
from yaml import YAMLError from yaml import YAMLError
from lm_eval import utils
from lm_eval.api.group import ConfigurableGroup, GroupConfig from lm_eval.api.group import ConfigurableGroup, GroupConfig
from lm_eval.api.task import ConfigurableTask, Task
from lm_eval.evaluator_utils import get_subtask_list from lm_eval.evaluator_utils import get_subtask_list
from lm_eval.utils import pattern_match, setup_logging
if TYPE_CHECKING:
from lm_eval.api.task import ConfigurableTask, Task
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys()) GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
...@@ -190,6 +204,7 @@ class TaskManager: ...@@ -190,6 +204,7 @@ class TaskManager:
""" """
@profile
def __init__( def __init__(
self, self,
verbosity: Optional[str] = None, verbosity: Optional[str] = None,
...@@ -198,7 +213,7 @@ class TaskManager: ...@@ -198,7 +213,7 @@ class TaskManager:
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
) -> None: ) -> None:
if verbosity is not None: if verbosity is not None:
utils.setup_logging(verbosity) setup_logging(verbosity)
self.include_path = include_path self.include_path = include_path
self.metadata = metadata self.metadata = metadata
self._task_index = self.initialize_tasks( self._task_index = self.initialize_tasks(
...@@ -222,6 +237,7 @@ class TaskManager: ...@@ -222,6 +237,7 @@ class TaskManager:
self.task_group_map = collections.defaultdict(list) self.task_group_map = collections.defaultdict(list)
@profile
def initialize_tasks( def initialize_tasks(
self, self,
include_path: Optional[Union[str, Path, List[Union[str, Path]]]] = None, include_path: Optional[Union[str, Path, List[Union[str, Path]]]] = None,
...@@ -375,7 +391,7 @@ class TaskManager: ...@@ -375,7 +391,7 @@ class TaskManager:
return "".join(parts) return "".join(parts)
def match_tasks(self, task_list: list[str]) -> list[str]: def match_tasks(self, task_list: list[str]) -> list[str]:
return utils.pattern_match(task_list, self.all_tasks) return pattern_match(task_list, self.all_tasks)
def _name_is_registered(self, name: str) -> bool: def _name_is_registered(self, name: str) -> bool:
return name in self.all_tasks return name in self.all_tasks
...@@ -492,9 +508,11 @@ class TaskManager: ...@@ -492,9 +508,11 @@ class TaskManager:
parent_name: Optional[str] = None, parent_name: Optional[str] = None,
update_config: Optional[Dict] = None, update_config: Optional[Dict] = None,
) -> Mapping: ) -> Mapping:
from lm_eval.api.task import ConfigurableTask, Task
def _load_task( def _load_task(
config: Dict, task: str, yaml_path: Optional[str] = None config: Dict, task: str, yaml_path: Optional[str] = None
) -> Dict[str, Union[ConfigurableTask, Task]]: ) -> Dict[str, Union["ConfigurableTask", "Task"]]:
if "include" in config: if "include" in config:
# Store the task name to preserve it after include processing # Store the task name to preserve it after include processing
original_task_name = config.get("task", task) original_task_name = config.get("task", task)
...@@ -704,6 +722,7 @@ class TaskManager: ...@@ -704,6 +722,7 @@ class TaskManager:
def load_config(self, config: Dict) -> Mapping: def load_config(self, config: Dict) -> Mapping:
return self._load_individual_task_or_group(config) return self._load_individual_task_or_group(config)
@profile
def _get_task_and_group(self, task_dir: Union[str, Path]) -> Dict[str, Dict]: def _get_task_and_group(self, task_dir: Union[str, Path]) -> Dict[str, Dict]:
"""Creates a dictionary of tasks index with the following metadata, """Creates a dictionary of tasks index with the following metadata,
- `type`, that can be either `task`, `python_task`, `group` or `tags`. - `type`, that can be either `task`, `python_task`, `group` or `tags`.
...@@ -839,7 +858,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str: ...@@ -839,7 +858,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str:
return "{dataset_path}".format(**task_config) return "{dataset_path}".format(**task_config)
def get_task_name_from_object(task_object: Union[ConfigurableTask, Task]) -> str: def get_task_name_from_object(task_object: Union["ConfigurableTask", "Task"]) -> str:
if hasattr(task_object, "config"): if hasattr(task_object, "config"):
return task_object._config["task"] return task_object._config["task"]
...@@ -879,10 +898,11 @@ def _check_duplicates(task_dict: Dict[str, List[str]]) -> None: ...@@ -879,10 +898,11 @@ def _check_duplicates(task_dict: Dict[str, List[str]]) -> None:
) )
@profile
def get_task_dict( def get_task_dict(
task_name_list: Union[str, List[Union[str, Dict, Task]]], task_name_list: Union[str, List[Union[str, Dict, "Task"]]],
task_manager: Optional[TaskManager] = None, task_manager: Optional[TaskManager] = None,
) -> Dict[str, Union[ConfigurableTask, Task]]: ) -> Dict[str, Union["ConfigurableTask", "Task"]]:
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object. """Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
:param task_name_list: List[Union[str, Dict, Task]] :param task_name_list: List[Union[str, Dict, Task]]
...@@ -896,6 +916,7 @@ def get_task_dict( ...@@ -896,6 +916,7 @@ def get_task_dict(
:return :return
Dictionary of task objects Dictionary of task objects
""" """
from lm_eval.api.task import ConfigurableTask, Task
task_name_from_string_dict = {} task_name_from_string_dict = {}
task_name_from_config_dict = {} task_name_from_config_dict = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment