"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "d18d38c4a4a2885fd43e9d70cea9da7c0b4605fd"
Commit 7fcfb4ac authored by Baber's avatar Baber
Browse files

refactor: simplify docstrings and improve task name matching logic

parent 5e632643
import abc
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from inspect import getsource from inspect import getsource
from typing import Any, Callable, List, Optional, Union from typing import Any, Callable, List, Optional, Union
...@@ -83,8 +82,14 @@ class GroupConfig(dict): ...@@ -83,8 +82,14 @@ class GroupConfig(dict):
except (TypeError, OSError): except (TypeError, OSError):
return str(value) return str(value)
@property
def version(self) -> str:
"""Returns the version of the group configuration."""
return self.metadata.get("version", "1.0")
class ConfigurableGroup(abc.ABC): @dataclass
class ConfigurableGroup:
def __init__( def __init__(
self, self,
config: Optional[dict] = None, config: Optional[dict] = None,
......
...@@ -73,19 +73,7 @@ _IGNORE_DIRS = ( ...@@ -73,19 +73,7 @@ _IGNORE_DIRS = (
def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None: def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None:
""" """YAML constructor that ignores !function tags during simple parsing."""
YAML constructor that ignores !function tags during simple parsing.
This is used when mode="simple" to skip function resolution for
faster indexing operations.
Args:
loader: YAML loader instance
node: YAML node being processed
Returns:
None
"""
return None return None
...@@ -129,8 +117,7 @@ def _import_function(qualname: str, *, base_path: Path) -> Callable: ...@@ -129,8 +117,7 @@ def _import_function(qualname: str, *, base_path: Path) -> Callable:
Dynamically import a function from a Python module relative to base_path. Dynamically import a function from a Python module relative to base_path.
This function enables YAML files to reference Python functions using This function enables YAML files to reference Python functions using
the !function tag. It supports dot notation for nested modules and the !function tag. Supports dot notation for nested modules.
caches imported modules for performance.
Args: Args:
qualname: Qualified function name like "my_module.my_function" qualname: Qualified function name like "my_module.my_function"
...@@ -180,7 +167,7 @@ def _parse_yaml_file(path: Path, mode: str) -> dict: ...@@ -180,7 +167,7 @@ def _parse_yaml_file(path: Path, mode: str) -> dict:
@functools.lru_cache(maxsize=4096) @functools.lru_cache(maxsize=4096)
def _get_cached_config(yaml_path: Path, mode: str) -> dict: def _get_cached_config(yaml_path: Path, mode: str) -> dict:
"""Load and cache resolved YAML configs with LRU eviction.""" """Load and cache resolved YAML configs"""
# Parse the YAML file # Parse the YAML file
yaml_config = _parse_yaml_file(yaml_path, mode) yaml_config = _parse_yaml_file(yaml_path, mode)
yaml_dir = yaml_path.parent yaml_dir = yaml_path.parent
...@@ -288,7 +275,7 @@ def load_yaml_config( ...@@ -288,7 +275,7 @@ def load_yaml_config(
return final_cfg return final_cfg
def iter_yaml_files(root: Path) -> Generator[Path, Any, None]: def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, None]:
""" """
Recursively iterate over all YAML files in a directory tree. Recursively iterate over all YAML files in a directory tree.
...@@ -306,7 +293,7 @@ def iter_yaml_files(root: Path) -> Generator[Path, Any, None]: ...@@ -306,7 +293,7 @@ def iter_yaml_files(root: Path) -> Generator[Path, Any, None]:
""" """
for p in iglob("**/*.yaml", root_dir=root, recursive=True): for p in iglob("**/*.yaml", root_dir=root, recursive=True):
# ignore check # ignore check
if Path(p).parts[0] in _IGNORE_DIRS: if Path(p).parts[0] in ignore:
continue continue
yield root / p yield root / p
...@@ -352,7 +339,7 @@ class TaskManager: ...@@ -352,7 +339,7 @@ class TaskManager:
verbosity: Optional[str] = None, verbosity: Optional[str] = None,
include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None, include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None,
include_defaults: bool = True, include_defaults: bool = True,
metadata: Optional[dict] = None, metadata: Optional[dict[str, dict[str, Any]]] = None,
) -> None: ) -> None:
""" """
Initialize the TaskManager. Initialize the TaskManager.
...@@ -548,21 +535,7 @@ class TaskManager: ...@@ -548,21 +535,7 @@ class TaskManager:
return "".join(parts) return "".join(parts)
def match_tasks(self, task_list: list[str]) -> list[str]: def match_tasks(self, task_list: list[str]) -> list[str]:
""" """Match task names using glob-style pattern matching."""
Match task names using pattern matching.
Supports glob-style patterns and returns all matching task names.
Args:
task_list: List of task name patterns to match
Returns:
List of matching task names
Example:
>>> tm.match_tasks(["hella*", "arc_*"])
['hellaswag', 'arc_easy', 'arc_challenge']
"""
return pattern_match(task_list, self.all_tasks) return pattern_match(task_list, self.all_tasks)
def _name_is_registered(self, name: str) -> bool: def _name_is_registered(self, name: str) -> bool:
...@@ -738,276 +711,195 @@ class TaskManager: ...@@ -738,276 +711,195 @@ class TaskManager:
else False else False
) )
def _load_individual_task_or_group( ###############################################################################
# NEW: Refactored _load_individual_task_or_group and helper methods #
###############################################################################
def _create_task_object(
self, self,
name_or_config: Optional[Union[str, dict]] = None, cfg: dict,
parent_name: Optional[str] = None, task_name: str,
update_config: Optional[dict] = None, yaml_path: str | None,
) -> Mapping: ) -> dict:
""" """
Load a single task or group with all its configurations and dependencies. Instantiate a single task (ConfigurableTask **or** python-task) from *cfg*.
Returns {task_name: task_object}.
This is the core method for instantiating task objects from either task names
or configuration dictionaries. It handles complex scenarios including:
- Individual tasks and Python class-based tasks
- Groups and their constituent subtasks
- Tags and their associated tasks
- Configuration merging and inheritance
- Duplicate detection and name resolution
- Include processing and YAML inheritance
Args:
name_or_config: Either a task name (str) or configuration dict.
If str, looks up the task in the index.
If dict, processes as inline configuration.
parent_name: Name of parent group (for duplicate detection)
update_config: Additional configuration to merge into task configs
Returns:
Mapping of task/group names to instantiated task objects.
For individual tasks: {task_name: ConfigurableTask}
For groups: {group_name: {subtask1: Task1, subtask2: Task2, ...}}
Example:
Load individual task::
task_dict = tm._load_individual_task_or_group("hellaswag")
# Returns: {"hellaswag": ConfigurableTask(...)}
Load with config override::
task_dict = tm._load_individual_task_or_group(
{"task": "hellaswag", "num_fewshot": 5}
)
Load a group::
group_dict = tm._load_individual_task_or_group("arc_group")
# Returns: {"arc_group": {"arc_easy": Task1, "arc_challenge": Task2}}
""" """
from lm_eval.api.task import ConfigurableTask, Task from lm_eval.api.task import ConfigurableTask, Task # local import avoids cycle
def _load_task( # ---- include handling ---------------------------------------------------
config: dict, task: str, yaml_path: Optional[str] = None if "include" in cfg:
) -> dict[str, Union["ConfigurableTask", "Task"]]: # keep original name so include merging cannot clobber it
""" orig_name = cfg.get("task", task_name)
Create a single task object from configuration. cfg = {
**load_yaml_config( # recurse once, cached
Handles include processing, Python class instantiation, and metadata injection. yaml_path=Path(yaml_path) if yaml_path else None,
yaml_config={"include": cfg.pop("include")},
Args: mode="full" if yaml_path else "simple",
config: Task configuration dictionary ),
task: Task name **cfg,
yaml_path: Path to source YAML file (for include resolution) "task": orig_name,
}
Returns:
Dictionary mapping task name to instantiated task object
"""
if "include" in config:
# Store the task name to preserve it after include processing
original_task_name = config.get("task", task)
config = {
**load_yaml_config(
yaml_path=Path(yaml_path),
yaml_config={"include": config.pop("include")},
mode="full" if yaml_path else "simple",
),
**config,
"task": original_task_name,
}
# Ensure the task name from the group config is preserved
# This prevents tasks with the same include from being treated as duplicates
if self._config_is_python_task(config): # ---- metadata merge -----------------------------------------------------
if self._class_has_config_in_constructor(config["class"]): if self.metadata is not None:
task_object = config["class"](config=config) cfg["metadata"] = cfg.get("metadata", {}) | self.metadata
else: else:
task_object = config["class"]() cfg["metadata"] = cfg.get("metadata", {})
if isinstance(task_object, ConfigurableTask):
# very scuffed: set task name here. TODO: fixme? # ---- python-task vs YAML-task -------------------------------------------
task_object.config.task = task if self._config_is_python_task(cfg):
cls = cfg["class"]
task_obj: Task
if self._class_has_config_in_constructor(cls):
task_obj = cls(config=cfg)
else: else:
if self.metadata is not None: task_obj = cls()
config["metadata"] = config.get("metadata", {}) | self.metadata # make sure name propagates when the class inherits ConfigurableTask
else: if isinstance(task_obj, ConfigurableTask): # type: ignore
config["metadata"] = config.get("metadata", {}) task_obj.config.task = task_name
task_object = ConfigurableTask(config=config) else:
task_obj = ConfigurableTask(config=cfg) # type: ignore
return {task: task_object}
def _get_group_and_subtask_from_config(
config: dict,
) -> tuple[ConfigurableGroup, list[str]]:
"""
Extract group object and subtask list from group configuration.
Expands any tags in the task list to their constituent tasks. return {task_name: task_obj}
Args: def _create_group_object(
config: Group configuration dictionary self,
cfg: dict,
Returns: parent_name: str | None = None,
Tuple of (ConfigurableGroup, list of subtask names) ) -> tuple[ConfigurableGroup, list[Union[str, dict]]]:
""" """
if self.metadata is not None: Build ConfigurableGroup and return (group_obj, subtask_names).
config["metadata"] = config.get("metadata", {}) | self.metadata Resolves tag expansion.
group_name = ConfigurableGroup(config=config) """
subtask_list = [] if self.metadata is not None:
for task in group_name.config["task"]: cfg["metadata"] = cfg.get("metadata", {}) | self.metadata
if isinstance(task, str) and self._name_is_tag(task):
subtask_list.extend(self._get_tasklist(task)) grp = ConfigurableGroup(config=cfg)
else: subtasks: list[Union[str, dict]] = []
subtask_list.append(task) for t in grp.config["task"]:
return group_name, subtask_list if isinstance(t, str) and self._name_is_tag(t):
subtasks.extend(self._get_tasklist(t))
def _process_group_config( else:
config: dict, update_config: Optional[dict] = None subtasks.append(t)
) -> tuple[dict, Optional[dict]]: return grp, subtasks
"""
Separate group-specific config from task-level config overrides.
Group-only keys (like 'group', 'aggregate') stay with the group, def _load_subtasks(
while other keys become config overrides for constituent tasks. self,
subtasks: list[Union[str, dict]],
parent_name: Union[str, ConfigurableGroup, None],
update_config: dict | None,
) -> Mapping:
"""Return merged mapping of all subtasks, handling duplicates."""
fn = functools.partial(
self._load_individual_task_or_group,
parent_name=parent_name,
update_config=update_config,
)
return dict(collections.ChainMap(*map(fn, reversed(subtasks))))
Args: def _load_individual_task_or_group(
config: Full configuration dictionary self,
update_config: Additional config to merge payload: str | dict,
*,
parent_name: str | None = None,
update_config: dict | None = None,
) -> Mapping:
"""
Public helper that turns *payload* (str task/group/tag **or** dict config)
into a nested Mapping of {name_or_group_obj: task_obj | sub_mapping}.
"""
Returns: # ------------------------------------------------------------------ STRING
Tuple of (group_config, task_update_config) if isinstance(payload, str):
""" # If caller supplied extra overrides, treat as dict immediately
if update_config is not None: if update_config:
config = {**config, **update_config} return self._load_individual_task_or_group(
_update_config = { {"task": payload, **update_config},
k: v for k, v in config.items() if k not in GROUP_ONLY_KEYS parent_name=parent_name,
} )
if not bool(_update_config):
_update_config = None
group_config = {k: v for k, v in config.items() if k in GROUP_ONLY_KEYS}
return group_config, _update_config
if isinstance(name_or_config, str):
if update_config is not None:
# Process name_or_config as a dict instead
name_or_config = {"task": name_or_config, **update_config}
elif self._name_is_task(name_or_config) or self._name_is_python_task(
name_or_config
):
# Get the yaml_path for this task
yaml_path = self._get_yaml_path(name_or_config)
task_config = self._get_config(name_or_config)
# Handle task_list configs
if "task_list" in task_config:
# Find the specific task entry
task_specific_config = None
for task_entry in task_config["task_list"]:
if (
isinstance(task_entry, dict)
and task_entry.get("task") == name_or_config
):
task_specific_config = task_entry
break
if task_specific_config:
# Create base config without task_list
base_config = {
k: v for k, v in task_config.items() if k != "task_list"
}
# Merge using helper method
task_config = self._merge_task_configs(
base_config, task_specific_config, name_or_config
)
else:
# Task not found in task_list, shouldn't happen if indexing worked correctly
eval_logger.warning(
f"Task {name_or_config} not found in task_list"
)
task_config = {"task": name_or_config}
return _load_task(task_config, task=name_or_config, yaml_path=yaml_path) # ------------ registered TASK (YAML or python) -----------------
else: if self._name_is_task(payload) or self._name_is_python_task(payload):
subtask_list = self._get_tasklist(name_or_config) yaml_path = self._get_yaml_path(payload)
if subtask_list == -1: cfg = self._get_config(payload)
group_config = self._get_config(name_or_config)
group_config, update_config = _process_group_config(group_config) # task_list configs: extract the per-task override ------------
group_name, subtask_list = _get_group_and_subtask_from_config( if "task_list" in cfg:
group_config override = next(
(
entry
for entry in cfg["task_list"]
if isinstance(entry, dict) and entry.get("task") == payload
),
None,
) )
else: base = {k: v for k, v in cfg.items() if k != "task_list"}
if self._name_is_tag(name_or_config): if override:
return self._process_tag_subtasks( cfg = {**base, **override, "task": payload}
name_or_config, return self._create_task_object(cfg, payload, yaml_path)
name_or_config
if isinstance(name_or_config, dict) # ------------ registered GROUP ----------------------------------
else None, if self._name_is_group(payload):
) group_cfg = self._get_config(payload)
else: grp_only = {k: v for k, v in group_cfg.items() if k in GROUP_ONLY_KEYS}
group_name = ConfigurableGroup( grp_obj, subtasks = self._create_group_object(grp_only, parent_name)
config={"group": name_or_config, "task": subtask_list} return {
) grp_obj: self._load_subtasks(subtasks, grp_obj, update_config=None)
}
if isinstance(name_or_config, dict): # ------------ registered TAG ------------------------------------
if self._config_is_task(name_or_config): if self._name_is_tag(payload):
name = name_or_config.pop("task") return self._process_tag_subtasks(payload, update_config=None)
if update_config is not None:
name_or_config = {**name_or_config, **update_config} raise ValueError(f"Unknown task / group / tag name: {payload!r}")
# If the name is registered as a group
if self._name_is_group(name): # ------------------------------------------------------------------- DICT
group_config = self._get_config(name) if isinstance(payload, dict):
# ------------------ simple 'task: name' dict --------------------
group_config, update_config = _process_group_config( if self._config_is_task(payload):
group_config, name_or_config name = payload["task"]
) # override existing registered YAML if exists
group_name, subtask_list = _get_group_and_subtask_from_config( if self._name_is_registered(name):
group_config base_cfg = self._get_config(name)
) yaml_path = self._get_yaml_path(name)
elif self._name_is_tag(name): merged = {**base_cfg, **payload}
return self._process_tag_subtasks(name, name_or_config)
else: else:
merged = payload
yaml_path = None yaml_path = None
if self._name_is_registered(name):
yaml_path = self._get_yaml_path(name)
base_task_config = self._get_config(name)
# Check if this is a duplicate.
if parent_name is not None:
num_duplicate = len(
list(
filter(
lambda x: x.startswith(name),
self.task_group_map[parent_name],
)
)
)
if num_duplicate > 0:
name = f"{name}-{num_duplicate}"
self.task_group_map[parent_name].append(name)
task_config = {
**base_task_config,
**name_or_config,
}
else:
task_config = name_or_config
return _load_task(task_config, task=name, yaml_path=yaml_path)
else:
group_config, update_config = _process_group_config(name_or_config)
group_name, subtask_list = _get_group_and_subtask_from_config(
group_config
)
fn = partial( # duplicate-naming guard when inside a group
self._load_individual_task_or_group, if parent_name is not None:
parent_name=group_name, count = len(
update_config=update_config, [
n
for n in self.task_group_map[parent_name]
if n.startswith(name)
]
)
if count:
name = f"{name}-{count}"
self.task_group_map[parent_name].append(name)
return self._create_task_object(merged, name, yaml_path)
# ----------------- literal group dict (task: [...]) -------------
if self._config_is_group(payload):
grp_cfg = {k: v for k, v in payload.items() if k in GROUP_ONLY_KEYS}
sub_override = {
k: v for k, v in payload.items() if k not in GROUP_ONLY_KEYS
} or None
grp_obj, subtasks = self._create_group_object(grp_cfg, parent_name)
return {grp_obj: self._load_subtasks(subtasks, grp_obj, sub_override)}
# ----------------- python-task dict ('class': …) ----------------
if self._config_is_python_task(payload):
name = payload["task"]
return self._create_task_object(payload, name, yaml_path=None)
raise TypeError(
f"_load_individual_task_or_group expected str | dict, got {type(payload)}"
) )
return {
group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
}
def load_task_or_group( def load_task_or_group(
self, task_list: Optional[Union[str, list[str]]] = None self, task_list: Optional[Union[str, list[str]]] = None
...@@ -1363,64 +1255,45 @@ def get_task_dict( ...@@ -1363,64 +1255,45 @@ def get_task_dict(
tm = TaskManager(include_path="/custom/tasks") tm = TaskManager(include_path="/custom/tasks")
tasks = get_task_dict(["custom_task"], task_manager=tm) tasks = get_task_dict(["custom_task"], task_manager=tm)
""" """
from lm_eval.api.task import ConfigurableTask, Task from lm_eval.api.task import Task
task_name_from_string_dict = {}
task_name_from_config_dict = {}
task_name_from_object_dict = {}
# Normalize input to list
if isinstance(task_name_list, str): if isinstance(task_name_list, str):
task_name_list = [task_name_list] task_name_list = [task_name_list]
elif isinstance(task_name_list, list): elif not isinstance(task_name_list, list):
if not all([isinstance(task, (str, dict, Task)) for task in task_name_list]):
raise TypeError(
"Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
)
else:
raise TypeError( raise TypeError(
f"Expected a 'str' or 'list' but received {type(task_name_list)}." f"Expected a 'str' or 'list' but received {type(task_name_list)}."
) )
string_task_name_list = [task for task in task_name_list if isinstance(task, str)] # Validate list items
others_task_name_list = [ if not all(isinstance(task, (str, dict, Task)) for task in task_name_list):
task for task in task_name_list if not isinstance(task, str) raise TypeError(
] "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
if len(string_task_name_list) > 0:
if task_manager is None:
task_manager = TaskManager()
task_name_from_string_dict = task_manager.load_task_or_group(
string_task_name_list
) )
for task_element in others_task_name_list: # Ensure we have a task manager
if isinstance(task_element, dict): if task_manager is None:
task_name_from_config_dict = { task_manager = TaskManager()
**task_name_from_config_dict,
**task_manager.load_config(config=task_element),
}
elif isinstance(task_element, Task):
task_name_from_object_dict = {
**task_name_from_object_dict,
get_task_name_from_object(task_element): task_element,
}
if not set(task_name_from_string_dict.keys()).isdisjoint(
set(task_name_from_object_dict.keys())
):
raise ValueError
final_task_dict = {
**task_name_from_string_dict,
**task_name_from_config_dict,
**task_name_from_object_dict,
}
# behavior can get odd if one tries to invoke several groups that "compete" for the same task. # Process all items
# (notably, because one could request several num_fewshot values at once in GroupConfig overrides for the subtask final_task_dict = {}
# and we'd be unsure which to use and report.) for task_spec in task_name_list:
# we explicitly check and error in this case. if isinstance(task_spec, Task):
# Pre-instantiated task object
task_name = get_task_name_from_object(task_spec)
if task_name in final_task_dict:
raise ValueError(f"Duplicate task name: {task_name}")
final_task_dict[task_name] = task_spec
else:
# String or dict - use load_task_or_group
result = task_manager.load_task_or_group(task_spec)
# Check for duplicate names
for name in result:
if name in final_task_dict:
raise ValueError(f"Duplicate task name: {name}")
final_task_dict.update(result)
# Check for conflicting group memberships
_check_duplicates(get_subtask_list(final_task_dict)) _check_duplicates(get_subtask_list(final_task_dict))
return final_task_dict return final_task_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment