Commit 7fcfb4ac authored by Baber
Browse files

refactor: simplify docstrings and improve task name matching logic

parent 5e632643
import abc
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import Any, Callable, List, Optional, Union
......@@ -83,8 +82,14 @@ class GroupConfig(dict):
except (TypeError, OSError):
return str(value)
@property
def version(self) -> str:
    """
    Return the version of the group configuration.

    The value is read from the ``"version"`` key of ``self.metadata``;
    ``"1.0"`` is returned when that key is absent.
    """
    # NOTE(review): metadata is presumably optional on this config — confirm.
    # Falling back to an empty dict keeps the "1.0" default instead of
    # raising AttributeError when metadata is unset.
    metadata = self.metadata or {}
    return metadata.get("version", "1.0")
class ConfigurableGroup(abc.ABC):
@dataclass
class ConfigurableGroup:
def __init__(
self,
config: Optional[dict] = None,
......
......@@ -73,19 +73,7 @@ _IGNORE_DIRS = (
def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None:
    """
    YAML constructor that ignores !function tags during simple parsing.

    Used when mode="simple" so that function resolution is skipped.

    Args:
        loader: YAML loader instance (unused).
        node: YAML node being processed (unused).

    Returns:
        Always None.
    """
    return None
......@@ -129,8 +117,7 @@ def _import_function(qualname: str, *, base_path: Path) -> Callable:
Dynamically import a function from a Python module relative to base_path.
This function enables YAML files to reference Python functions using
the !function tag. It supports dot notation for nested modules and
caches imported modules for performance.
the !function tag. Supports dot notation for nested modules.
Args:
qualname: Qualified function name like "my_module.my_function"
......@@ -180,7 +167,7 @@ def _parse_yaml_file(path: Path, mode: str) -> dict:
@functools.lru_cache(maxsize=4096)
def _get_cached_config(yaml_path: Path, mode: str) -> dict:
"""Load and cache resolved YAML configs with LRU eviction."""
"""Load and cache resolved YAML configs"""
# Parse the YAML file
yaml_config = _parse_yaml_file(yaml_path, mode)
yaml_dir = yaml_path.parent
......@@ -288,7 +275,7 @@ def load_yaml_config(
return final_cfg
def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, None]:
    """
    Recursively iterate over all YAML files in a directory tree.

    Args:
        root: Directory to search.
        ignore: Container of directory names to skip. A file is skipped when
            any directory component of its path relative to *root* is in
            this container.

    Yields:
        ``root / relative_path`` for every ``*.yaml`` file found under
        *root* that is not inside an ignored directory.
    """
    for rel in iglob("**/*.yaml", root_dir=root, recursive=True):
        # The glob recurses, so an ignored directory can appear at any depth.
        # Check every directory component (everything but the file name),
        # not just the first one.
        if any(part in ignore for part in Path(rel).parts[:-1]):
            continue
        yield root / rel
......@@ -352,7 +339,7 @@ class TaskManager:
verbosity: Optional[str] = None,
include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None,
include_defaults: bool = True,
metadata: Optional[dict] = None,
metadata: Optional[dict[str, dict[str, Any]]] = None,
) -> None:
"""
Initialize the TaskManager.
......@@ -548,21 +535,7 @@ class TaskManager:
return "".join(parts)
def match_tasks(self, task_list: list[str]) -> list[str]:
    """
    Match task names using glob-style pattern matching.

    Args:
        task_list: Task name patterns, matched against ``self.all_tasks``.

    Returns:
        The result of ``pattern_match`` for these patterns.

    Example:
        >>> tm.match_tasks(["hella*", "arc_*"])
        ['hellaswag', 'arc_easy', 'arc_challenge']
    """
    return pattern_match(task_list, self.all_tasks)
def _name_is_registered(self, name: str) -> bool:
......@@ -738,276 +711,195 @@ class TaskManager:
else False
)
def _create_task_object(
    self,
    cfg: dict,
    task_name: str,
    yaml_path: str | None,
) -> dict:
    """
    Instantiate a single task (ConfigurableTask **or** python-task) from *cfg*.

    Args:
        cfg: Task configuration. A shallow copy is taken first, so the
            caller's dict is left untouched.
        task_name: Key of the returned mapping. Also written to
            ``config.task`` when a python task is a ConfigurableTask.
        yaml_path: Path of the YAML file the config came from, or None.
            It is forwarded to ``load_yaml_config`` when *cfg* has an
            ``include`` key ("full" mode when set, "simple" otherwise).

    Returns:
        {task_name: task_object}
    """
    from lm_eval.api.task import ConfigurableTask, Task  # local import avoids cycle

    # Work on a copy: popping "include" and injecting "metadata" below must
    # not leak into a dict the caller may reuse for another load.
    cfg = dict(cfg)

    # ---- include handling ---------------------------------------------------
    if "include" in cfg:
        # keep original name so include merging cannot clobber it
        orig_name = cfg.get("task", task_name)
        included = load_yaml_config(
            yaml_path=Path(yaml_path) if yaml_path else None,
            yaml_config={"include": cfg.pop("include")},
            mode="full" if yaml_path else "simple",
        )
        # keys given directly in cfg win over keys from the included file
        cfg = {**included, **cfg, "task": orig_name}

    # ---- metadata merge -----------------------------------------------------
    if self.metadata is not None:
        # `or {}` also covers an explicit `metadata: None` in the config;
        # manager-level metadata wins on key conflicts.
        cfg["metadata"] = (cfg.get("metadata") or {}) | self.metadata
    else:
        cfg["metadata"] = cfg.get("metadata", {})

    # ---- python-task vs YAML-task -------------------------------------------
    if self._config_is_python_task(cfg):
        task_cls = cfg["class"]
        task_obj: Task
        if self._class_has_config_in_constructor(task_cls):
            task_obj = task_cls(config=cfg)
        else:
            task_obj = task_cls()
        # make sure name propagates when the class inherits ConfigurableTask
        if isinstance(task_obj, ConfigurableTask):  # type: ignore
            task_obj.config.task = task_name
    else:
        task_obj = ConfigurableTask(config=cfg)  # type: ignore

    return {task_name: task_obj}
def _create_group_object(
    self,
    cfg: dict,
    parent_name: str | None = None,
) -> tuple[ConfigurableGroup, list[Union[str, dict]]]:
    """
    Build the ConfigurableGroup for *cfg* and list its subtasks.

    When the manager has metadata, it is merged into ``cfg["metadata"]``
    (manager values win) before the group object is created. Every string
    entry of the group's ``task`` list that names a tag is replaced by the
    tasks of that tag; all other entries are kept as they are.

    Args:
        cfg: Group configuration; its ``metadata`` key may be rewritten.
        parent_name: Accepted for signature parity; not used here.

    Returns:
        (group_obj, subtask_entries)
    """
    shared_meta = self.metadata
    if shared_meta is not None:
        cfg["metadata"] = cfg.get("metadata", {}) | shared_meta

    group = ConfigurableGroup(config=cfg)

    expanded: list[Union[str, dict]] = []
    for entry in group.config["task"]:
        names_tag = isinstance(entry, str) and self._name_is_tag(entry)
        if not names_tag:
            expanded.append(entry)
            continue
        # a tag stands for several tasks: splice them in place
        expanded.extend(self._get_tasklist(entry))
    return group, expanded
Group-only keys (like 'group', 'aggregate') stay with the group,
while other keys become config overrides for constituent tasks.
def _load_subtasks(
self,
subtasks: list[Union[str, dict]],
parent_name: Union[str, ConfigurableGroup, None],
update_config: dict | None,
) -> Mapping:
"""Return merged mapping of all subtasks, handling duplicates."""
fn = functools.partial(
self._load_individual_task_or_group,
parent_name=parent_name,
update_config=update_config,
)
return dict(collections.ChainMap(*map(fn, reversed(subtasks))))
Args:
config: Full configuration dictionary
update_config: Additional config to merge
def _load_individual_task_or_group(
self,
payload: str | dict,
*,
parent_name: str | None = None,
update_config: dict | None = None,
) -> Mapping:
"""
Public helper that turns *payload* (str task/group/tag **or** dict config)
into a nested Mapping of {name_or_group_obj: task_obj | sub_mapping}.
"""
Returns:
Tuple of (group_config, task_update_config)
"""
if update_config is not None:
config = {**config, **update_config}
_update_config = {
k: v for k, v in config.items() if k not in GROUP_ONLY_KEYS
}
if not bool(_update_config):
_update_config = None
group_config = {k: v for k, v in config.items() if k in GROUP_ONLY_KEYS}
return group_config, _update_config
if isinstance(name_or_config, str):
if update_config is not None:
# Process name_or_config as a dict instead
name_or_config = {"task": name_or_config, **update_config}
elif self._name_is_task(name_or_config) or self._name_is_python_task(
name_or_config
):
# Get the yaml_path for this task
yaml_path = self._get_yaml_path(name_or_config)
task_config = self._get_config(name_or_config)
# Handle task_list configs
if "task_list" in task_config:
# Find the specific task entry
task_specific_config = None
for task_entry in task_config["task_list"]:
if (
isinstance(task_entry, dict)
and task_entry.get("task") == name_or_config
):
task_specific_config = task_entry
break
if task_specific_config:
# Create base config without task_list
base_config = {
k: v for k, v in task_config.items() if k != "task_list"
}
# Merge using helper method
task_config = self._merge_task_configs(
base_config, task_specific_config, name_or_config
)
else:
# Task not found in task_list, shouldn't happen if indexing worked correctly
eval_logger.warning(
f"Task {name_or_config} not found in task_list"
)
task_config = {"task": name_or_config}
# ------------------------------------------------------------------ STRING
if isinstance(payload, str):
# If caller supplied extra overrides, treat as dict immediately
if update_config:
return self._load_individual_task_or_group(
{"task": payload, **update_config},
parent_name=parent_name,
)
return _load_task(task_config, task=name_or_config, yaml_path=yaml_path)
else:
subtask_list = self._get_tasklist(name_or_config)
if subtask_list == -1:
group_config = self._get_config(name_or_config)
group_config, update_config = _process_group_config(group_config)
group_name, subtask_list = _get_group_and_subtask_from_config(
group_config
# ------------ registered TASK (YAML or python) -----------------
if self._name_is_task(payload) or self._name_is_python_task(payload):
yaml_path = self._get_yaml_path(payload)
cfg = self._get_config(payload)
# task_list configs: extract the per-task override ------------
if "task_list" in cfg:
override = next(
(
entry
for entry in cfg["task_list"]
if isinstance(entry, dict) and entry.get("task") == payload
),
None,
)
else:
if self._name_is_tag(name_or_config):
return self._process_tag_subtasks(
name_or_config,
name_or_config
if isinstance(name_or_config, dict)
else None,
)
else:
group_name = ConfigurableGroup(
config={"group": name_or_config, "task": subtask_list}
)
base = {k: v for k, v in cfg.items() if k != "task_list"}
if override:
cfg = {**base, **override, "task": payload}
return self._create_task_object(cfg, payload, yaml_path)
# ------------ registered GROUP ----------------------------------
if self._name_is_group(payload):
group_cfg = self._get_config(payload)
grp_only = {k: v for k, v in group_cfg.items() if k in GROUP_ONLY_KEYS}
grp_obj, subtasks = self._create_group_object(grp_only, parent_name)
return {
grp_obj: self._load_subtasks(subtasks, grp_obj, update_config=None)
}
if isinstance(name_or_config, dict):
if self._config_is_task(name_or_config):
name = name_or_config.pop("task")
if update_config is not None:
name_or_config = {**name_or_config, **update_config}
# If the name is registered as a group
if self._name_is_group(name):
group_config = self._get_config(name)
group_config, update_config = _process_group_config(
group_config, name_or_config
)
group_name, subtask_list = _get_group_and_subtask_from_config(
group_config
)
elif self._name_is_tag(name):
return self._process_tag_subtasks(name, name_or_config)
# ------------ registered TAG ------------------------------------
if self._name_is_tag(payload):
return self._process_tag_subtasks(payload, update_config=None)
raise ValueError(f"Unknown task / group / tag name: {payload!r}")
# ------------------------------------------------------------------- DICT
if isinstance(payload, dict):
# ------------------ simple 'task: name' dict --------------------
if self._config_is_task(payload):
name = payload["task"]
# override existing registered YAML if exists
if self._name_is_registered(name):
base_cfg = self._get_config(name)
yaml_path = self._get_yaml_path(name)
merged = {**base_cfg, **payload}
else:
merged = payload
yaml_path = None
if self._name_is_registered(name):
yaml_path = self._get_yaml_path(name)
base_task_config = self._get_config(name)
# Check if this is a duplicate.
if parent_name is not None:
num_duplicate = len(
list(
filter(
lambda x: x.startswith(name),
self.task_group_map[parent_name],
)
)
)
if num_duplicate > 0:
name = f"{name}-{num_duplicate}"
self.task_group_map[parent_name].append(name)
task_config = {
**base_task_config,
**name_or_config,
}
else:
task_config = name_or_config
return _load_task(task_config, task=name, yaml_path=yaml_path)
else:
group_config, update_config = _process_group_config(name_or_config)
group_name, subtask_list = _get_group_and_subtask_from_config(
group_config
)
fn = partial(
self._load_individual_task_or_group,
parent_name=group_name,
update_config=update_config,
# duplicate-naming guard when inside a group
if parent_name is not None:
count = len(
[
n
for n in self.task_group_map[parent_name]
if n.startswith(name)
]
)
if count:
name = f"{name}-{count}"
self.task_group_map[parent_name].append(name)
return self._create_task_object(merged, name, yaml_path)
# ----------------- literal group dict (task: [...]) -------------
if self._config_is_group(payload):
grp_cfg = {k: v for k, v in payload.items() if k in GROUP_ONLY_KEYS}
sub_override = {
k: v for k, v in payload.items() if k not in GROUP_ONLY_KEYS
} or None
grp_obj, subtasks = self._create_group_object(grp_cfg, parent_name)
return {grp_obj: self._load_subtasks(subtasks, grp_obj, sub_override)}
# ----------------- python-task dict ('class': …) ----------------
if self._config_is_python_task(payload):
name = payload["task"]
return self._create_task_object(payload, name, yaml_path=None)
raise TypeError(
f"_load_individual_task_or_group expected str | dict, got {type(payload)}"
)
return {
group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
}
def load_task_or_group(
self, task_list: Optional[Union[str, list[str]]] = None
......@@ -1363,64 +1255,45 @@ def get_task_dict(
tm = TaskManager(include_path="/custom/tasks")
tasks = get_task_dict(["custom_task"], task_manager=tm)
"""
from lm_eval.api.task import ConfigurableTask, Task
task_name_from_string_dict = {}
task_name_from_config_dict = {}
task_name_from_object_dict = {}
from lm_eval.api.task import Task
# Normalize input to list
if isinstance(task_name_list, str):
task_name_list = [task_name_list]
elif isinstance(task_name_list, list):
if not all([isinstance(task, (str, dict, Task)) for task in task_name_list]):
raise TypeError(
"Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
)
else:
elif not isinstance(task_name_list, list):
raise TypeError(
f"Expected a 'str' or 'list' but received {type(task_name_list)}."
)
string_task_name_list = [task for task in task_name_list if isinstance(task, str)]
others_task_name_list = [
task for task in task_name_list if not isinstance(task, str)
]
if len(string_task_name_list) > 0:
if task_manager is None:
task_manager = TaskManager()
task_name_from_string_dict = task_manager.load_task_or_group(
string_task_name_list
# Validate list items
if not all(isinstance(task, (str, dict, Task)) for task in task_name_list):
raise TypeError(
"Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
)
for task_element in others_task_name_list:
if isinstance(task_element, dict):
task_name_from_config_dict = {
**task_name_from_config_dict,
**task_manager.load_config(config=task_element),
}
elif isinstance(task_element, Task):
task_name_from_object_dict = {
**task_name_from_object_dict,
get_task_name_from_object(task_element): task_element,
}
if not set(task_name_from_string_dict.keys()).isdisjoint(
set(task_name_from_object_dict.keys())
):
raise ValueError
final_task_dict = {
**task_name_from_string_dict,
**task_name_from_config_dict,
**task_name_from_object_dict,
}
# Ensure we have a task manager
if task_manager is None:
task_manager = TaskManager()
# behavior can get odd if one tries to invoke several groups that "compete" for the same task.
# (notably, because one could request several num_fewshot values at once in GroupConfig overrides for the subtask
# and we'd be unsure which to use and report.)
# we explicitly check and error in this case.
# Process all items
final_task_dict = {}
for task_spec in task_name_list:
if isinstance(task_spec, Task):
# Pre-instantiated task object
task_name = get_task_name_from_object(task_spec)
if task_name in final_task_dict:
raise ValueError(f"Duplicate task name: {task_name}")
final_task_dict[task_name] = task_spec
else:
# String or dict - use load_task_or_group
result = task_manager.load_task_or_group(task_spec)
# Check for duplicate names
for name in result:
if name in final_task_dict:
raise ValueError(f"Duplicate task name: {name}")
final_task_dict.update(result)
# Check for conflicting group memberships
_check_duplicates(get_subtask_list(final_task_dict))
return final_task_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment