Commit 7fcfb4ac authored by Baber
Browse files

refactor: simplify docstrings and improve task name matching logic

parent 5e632643
import abc
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from inspect import getsource from inspect import getsource
from typing import Any, Callable, List, Optional, Union from typing import Any, Callable, List, Optional, Union
...@@ -83,8 +82,14 @@ class GroupConfig(dict): ...@@ -83,8 +82,14 @@ class GroupConfig(dict):
except (TypeError, OSError): except (TypeError, OSError):
return str(value) return str(value)
@property
def version(self) -> str:
"""Returns the version of the group configuration."""
return self.metadata.get("version", "1.0")
class ConfigurableGroup(abc.ABC): @dataclass
class ConfigurableGroup:
def __init__( def __init__(
self, self,
config: Optional[dict] = None, config: Optional[dict] = None,
......
...@@ -73,19 +73,7 @@ _IGNORE_DIRS = ( ...@@ -73,19 +73,7 @@ _IGNORE_DIRS = (
def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None: def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None:
""" """YAML constructor that ignores !function tags during simple parsing."""
YAML constructor that ignores !function tags during simple parsing.
This is used when mode="simple" to skip function resolution for
faster indexing operations.
Args:
loader: YAML loader instance
node: YAML node being processed
Returns:
None
"""
return None return None
...@@ -129,8 +117,7 @@ def _import_function(qualname: str, *, base_path: Path) -> Callable: ...@@ -129,8 +117,7 @@ def _import_function(qualname: str, *, base_path: Path) -> Callable:
Dynamically import a function from a Python module relative to base_path. Dynamically import a function from a Python module relative to base_path.
This function enables YAML files to reference Python functions using This function enables YAML files to reference Python functions using
the !function tag. It supports dot notation for nested modules and the !function tag. Supports dot notation for nested modules.
caches imported modules for performance.
Args: Args:
qualname: Qualified function name like "my_module.my_function" qualname: Qualified function name like "my_module.my_function"
...@@ -180,7 +167,7 @@ def _parse_yaml_file(path: Path, mode: str) -> dict: ...@@ -180,7 +167,7 @@ def _parse_yaml_file(path: Path, mode: str) -> dict:
@functools.lru_cache(maxsize=4096) @functools.lru_cache(maxsize=4096)
def _get_cached_config(yaml_path: Path, mode: str) -> dict: def _get_cached_config(yaml_path: Path, mode: str) -> dict:
"""Load and cache resolved YAML configs with LRU eviction.""" """Load and cache resolved YAML configs"""
# Parse the YAML file # Parse the YAML file
yaml_config = _parse_yaml_file(yaml_path, mode) yaml_config = _parse_yaml_file(yaml_path, mode)
yaml_dir = yaml_path.parent yaml_dir = yaml_path.parent
...@@ -288,7 +275,7 @@ def load_yaml_config( ...@@ -288,7 +275,7 @@ def load_yaml_config(
return final_cfg return final_cfg
def iter_yaml_files(root: Path) -> Generator[Path, Any, None]: def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, None]:
""" """
Recursively iterate over all YAML files in a directory tree. Recursively iterate over all YAML files in a directory tree.
...@@ -306,7 +293,7 @@ def iter_yaml_files(root: Path) -> Generator[Path, Any, None]: ...@@ -306,7 +293,7 @@ def iter_yaml_files(root: Path) -> Generator[Path, Any, None]:
""" """
for p in iglob("**/*.yaml", root_dir=root, recursive=True): for p in iglob("**/*.yaml", root_dir=root, recursive=True):
# ignore check # ignore check
if Path(p).parts[0] in _IGNORE_DIRS: if Path(p).parts[0] in ignore:
continue continue
yield root / p yield root / p
...@@ -352,7 +339,7 @@ class TaskManager: ...@@ -352,7 +339,7 @@ class TaskManager:
verbosity: Optional[str] = None, verbosity: Optional[str] = None,
include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None, include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None,
include_defaults: bool = True, include_defaults: bool = True,
metadata: Optional[dict] = None, metadata: Optional[dict[str, dict[str, Any]]] = None,
) -> None: ) -> None:
""" """
Initialize the TaskManager. Initialize the TaskManager.
...@@ -548,21 +535,7 @@ class TaskManager: ...@@ -548,21 +535,7 @@ class TaskManager:
return "".join(parts) return "".join(parts)
def match_tasks(self, task_list: list[str]) -> list[str]: def match_tasks(self, task_list: list[str]) -> list[str]:
""" """Match task names using glob-style pattern matching."""
Match task names using pattern matching.
Supports glob-style patterns and returns all matching task names.
Args:
task_list: List of task name patterns to match
Returns:
List of matching task names
Example:
>>> tm.match_tasks(["hella*", "arc_*"])
['hellaswag', 'arc_easy', 'arc_challenge']
"""
return pattern_match(task_list, self.all_tasks) return pattern_match(task_list, self.all_tasks)
def _name_is_registered(self, name: str) -> bool: def _name_is_registered(self, name: str) -> bool:
...@@ -738,276 +711,195 @@ class TaskManager: ...@@ -738,276 +711,195 @@ class TaskManager:
else False else False
) )
def _load_individual_task_or_group( ###############################################################################
self, # NEW: Refactored _load_individual_task_or_group and helper methods #
name_or_config: Optional[Union[str, dict]] = None, ###############################################################################
parent_name: Optional[str] = None,
update_config: Optional[dict] = None,
) -> Mapping:
"""
Load a single task or group with all its configurations and dependencies.
This is the core method for instantiating task objects from either task names
or configuration dictionaries. It handles complex scenarios including:
- Individual tasks and Python class-based tasks
- Groups and their constituent subtasks
- Tags and their associated tasks
- Configuration merging and inheritance
- Duplicate detection and name resolution
- Include processing and YAML inheritance
Args:
name_or_config: Either a task name (str) or configuration dict.
If str, looks up the task in the index.
If dict, processes as inline configuration.
parent_name: Name of parent group (for duplicate detection)
update_config: Additional configuration to merge into task configs
Returns:
Mapping of task/group names to instantiated task objects.
For individual tasks: {task_name: ConfigurableTask}
For groups: {group_name: {subtask1: Task1, subtask2: Task2, ...}}
Example:
Load individual task::
task_dict = tm._load_individual_task_or_group("hellaswag")
# Returns: {"hellaswag": ConfigurableTask(...)}
Load with config override::
task_dict = tm._load_individual_task_or_group(
{"task": "hellaswag", "num_fewshot": 5}
)
Load a group::
group_dict = tm._load_individual_task_or_group("arc_group") def _create_task_object(
# Returns: {"arc_group": {"arc_easy": Task1, "arc_challenge": Task2}} self,
cfg: dict,
task_name: str,
yaml_path: str | None,
) -> dict:
""" """
from lm_eval.api.task import ConfigurableTask, Task Instantiate a single task (ConfigurableTask **or** python-task) from *cfg*.
Returns {task_name: task_object}.
def _load_task(
config: dict, task: str, yaml_path: Optional[str] = None
) -> dict[str, Union["ConfigurableTask", "Task"]]:
""" """
Create a single task object from configuration. from lm_eval.api.task import ConfigurableTask, Task # local import avoids cycle
Handles include processing, Python class instantiation, and metadata injection. # ---- include handling ---------------------------------------------------
if "include" in cfg:
Args: # keep original name so include merging cannot clobber it
config: Task configuration dictionary orig_name = cfg.get("task", task_name)
task: Task name cfg = {
yaml_path: Path to source YAML file (for include resolution) **load_yaml_config( # recurse once, cached
yaml_path=Path(yaml_path) if yaml_path else None,
Returns: yaml_config={"include": cfg.pop("include")},
Dictionary mapping task name to instantiated task object
"""
if "include" in config:
# Store the task name to preserve it after include processing
original_task_name = config.get("task", task)
config = {
**load_yaml_config(
yaml_path=Path(yaml_path),
yaml_config={"include": config.pop("include")},
mode="full" if yaml_path else "simple", mode="full" if yaml_path else "simple",
), ),
**config, **cfg,
"task": original_task_name, "task": orig_name,
} }
# Ensure the task name from the group config is preserved # ---- metadata merge -----------------------------------------------------
# This prevents tasks with the same include from being treated as duplicates if self.metadata is not None:
cfg["metadata"] = cfg.get("metadata", {}) | self.metadata
if self._config_is_python_task(config):
if self._class_has_config_in_constructor(config["class"]):
task_object = config["class"](config=config)
else: else:
task_object = config["class"]() cfg["metadata"] = cfg.get("metadata", {})
if isinstance(task_object, ConfigurableTask):
# very scuffed: set task name here. TODO: fixme? # ---- python-task vs YAML-task -------------------------------------------
task_object.config.task = task if self._config_is_python_task(cfg):
cls = cfg["class"]
task_obj: Task
if self._class_has_config_in_constructor(cls):
task_obj = cls(config=cfg)
else: else:
if self.metadata is not None: task_obj = cls()
config["metadata"] = config.get("metadata", {}) | self.metadata # make sure name propagates when the class inherits ConfigurableTask
if isinstance(task_obj, ConfigurableTask): # type: ignore
task_obj.config.task = task_name
else: else:
config["metadata"] = config.get("metadata", {}) task_obj = ConfigurableTask(config=cfg) # type: ignore
task_object = ConfigurableTask(config=config)
return {task: task_object} return {task_name: task_obj}
def _get_group_and_subtask_from_config( def _create_group_object(
config: dict, self,
) -> tuple[ConfigurableGroup, list[str]]: cfg: dict,
parent_name: str | None = None,
) -> tuple[ConfigurableGroup, list[Union[str, dict]]]:
""" """
Extract group object and subtask list from group configuration. Build ConfigurableGroup and return (group_obj, subtask_names).
Resolves tag expansion.
Expands any tags in the task list to their constituent tasks.
Args:
config: Group configuration dictionary
Returns:
Tuple of (ConfigurableGroup, list of subtask names)
""" """
if self.metadata is not None: if self.metadata is not None:
config["metadata"] = config.get("metadata", {}) | self.metadata cfg["metadata"] = cfg.get("metadata", {}) | self.metadata
group_name = ConfigurableGroup(config=config)
subtask_list = []
for task in group_name.config["task"]:
if isinstance(task, str) and self._name_is_tag(task):
subtask_list.extend(self._get_tasklist(task))
else:
subtask_list.append(task)
return group_name, subtask_list
def _process_group_config( grp = ConfigurableGroup(config=cfg)
config: dict, update_config: Optional[dict] = None subtasks: list[Union[str, dict]] = []
) -> tuple[dict, Optional[dict]]: for t in grp.config["task"]:
""" if isinstance(t, str) and self._name_is_tag(t):
Separate group-specific config from task-level config overrides. subtasks.extend(self._get_tasklist(t))
else:
Group-only keys (like 'group', 'aggregate') stay with the group, subtasks.append(t)
while other keys become config overrides for constituent tasks. return grp, subtasks
Args: def _load_subtasks(
config: Full configuration dictionary self,
update_config: Additional config to merge subtasks: list[Union[str, dict]],
parent_name: Union[str, ConfigurableGroup, None],
update_config: dict | None,
) -> Mapping:
"""Return merged mapping of all subtasks, handling duplicates."""
fn = functools.partial(
self._load_individual_task_or_group,
parent_name=parent_name,
update_config=update_config,
)
return dict(collections.ChainMap(*map(fn, reversed(subtasks))))
Returns: def _load_individual_task_or_group(
Tuple of (group_config, task_update_config) self,
payload: str | dict,
*,
parent_name: str | None = None,
update_config: dict | None = None,
) -> Mapping:
"""
Public helper that turns *payload* (str task/group/tag **or** dict config)
into a nested Mapping of {name_or_group_obj: task_obj | sub_mapping}.
""" """
if update_config is not None:
config = {**config, **update_config}
_update_config = {
k: v for k, v in config.items() if k not in GROUP_ONLY_KEYS
}
if not bool(_update_config):
_update_config = None
group_config = {k: v for k, v in config.items() if k in GROUP_ONLY_KEYS}
return group_config, _update_config
if isinstance(name_or_config, str):
if update_config is not None:
# Process name_or_config as a dict instead
name_or_config = {"task": name_or_config, **update_config}
elif self._name_is_task(name_or_config) or self._name_is_python_task(
name_or_config
):
# Get the yaml_path for this task
yaml_path = self._get_yaml_path(name_or_config)
task_config = self._get_config(name_or_config)
# Handle task_list configs
if "task_list" in task_config:
# Find the specific task entry
task_specific_config = None
for task_entry in task_config["task_list"]:
if (
isinstance(task_entry, dict)
and task_entry.get("task") == name_or_config
):
task_specific_config = task_entry
break
if task_specific_config: # ------------------------------------------------------------------ STRING
# Create base config without task_list if isinstance(payload, str):
base_config = { # If caller supplied extra overrides, treat as dict immediately
k: v for k, v in task_config.items() if k != "task_list" if update_config:
} return self._load_individual_task_or_group(
# Merge using helper method {"task": payload, **update_config},
task_config = self._merge_task_configs( parent_name=parent_name,
base_config, task_specific_config, name_or_config
)
else:
# Task not found in task_list, shouldn't happen if indexing worked correctly
eval_logger.warning(
f"Task {name_or_config} not found in task_list"
) )
task_config = {"task": name_or_config}
return _load_task(task_config, task=name_or_config, yaml_path=yaml_path) # ------------ registered TASK (YAML or python) -----------------
else: if self._name_is_task(payload) or self._name_is_python_task(payload):
subtask_list = self._get_tasklist(name_or_config) yaml_path = self._get_yaml_path(payload)
if subtask_list == -1: cfg = self._get_config(payload)
group_config = self._get_config(name_or_config)
group_config, update_config = _process_group_config(group_config) # task_list configs: extract the per-task override ------------
group_name, subtask_list = _get_group_and_subtask_from_config( if "task_list" in cfg:
group_config override = next(
) (
else: entry
if self._name_is_tag(name_or_config): for entry in cfg["task_list"]
return self._process_tag_subtasks( if isinstance(entry, dict) and entry.get("task") == payload
name_or_config, ),
name_or_config None,
if isinstance(name_or_config, dict)
else None,
)
else:
group_name = ConfigurableGroup(
config={"group": name_or_config, "task": subtask_list}
) )
base = {k: v for k, v in cfg.items() if k != "task_list"}
if override:
cfg = {**base, **override, "task": payload}
return self._create_task_object(cfg, payload, yaml_path)
# ------------ registered GROUP ----------------------------------
if self._name_is_group(payload):
group_cfg = self._get_config(payload)
grp_only = {k: v for k, v in group_cfg.items() if k in GROUP_ONLY_KEYS}
grp_obj, subtasks = self._create_group_object(grp_only, parent_name)
return {
grp_obj: self._load_subtasks(subtasks, grp_obj, update_config=None)
}
if isinstance(name_or_config, dict): # ------------ registered TAG ------------------------------------
if self._config_is_task(name_or_config): if self._name_is_tag(payload):
name = name_or_config.pop("task") return self._process_tag_subtasks(payload, update_config=None)
if update_config is not None:
name_or_config = {**name_or_config, **update_config} raise ValueError(f"Unknown task / group / tag name: {payload!r}")
# If the name is registered as a group
if self._name_is_group(name): # ------------------------------------------------------------------- DICT
group_config = self._get_config(name) if isinstance(payload, dict):
# ------------------ simple 'task: name' dict --------------------
group_config, update_config = _process_group_config( if self._config_is_task(payload):
group_config, name_or_config name = payload["task"]
) # override existing registered YAML if exists
group_name, subtask_list = _get_group_and_subtask_from_config(
group_config
)
elif self._name_is_tag(name):
return self._process_tag_subtasks(name, name_or_config)
else:
yaml_path = None
if self._name_is_registered(name): if self._name_is_registered(name):
base_cfg = self._get_config(name)
yaml_path = self._get_yaml_path(name) yaml_path = self._get_yaml_path(name)
base_task_config = self._get_config(name) merged = {**base_cfg, **payload}
else:
merged = payload
yaml_path = None
# Check if this is a duplicate. # duplicate-naming guard when inside a group
if parent_name is not None: if parent_name is not None:
num_duplicate = len( count = len(
list( [
filter( n
lambda x: x.startswith(name), for n in self.task_group_map[parent_name]
self.task_group_map[parent_name], if n.startswith(name)
) ]
)
) )
if num_duplicate > 0: if count:
name = f"{name}-{num_duplicate}" name = f"{name}-{count}"
self.task_group_map[parent_name].append(name) self.task_group_map[parent_name].append(name)
task_config = { return self._create_task_object(merged, name, yaml_path)
**base_task_config,
**name_or_config,
}
else:
task_config = name_or_config
return _load_task(task_config, task=name, yaml_path=yaml_path)
else:
group_config, update_config = _process_group_config(name_or_config)
group_name, subtask_list = _get_group_and_subtask_from_config(
group_config
)
fn = partial( # ----------------- literal group dict (task: [...]) -------------
self._load_individual_task_or_group, if self._config_is_group(payload):
parent_name=group_name, grp_cfg = {k: v for k, v in payload.items() if k in GROUP_ONLY_KEYS}
update_config=update_config, sub_override = {
k: v for k, v in payload.items() if k not in GROUP_ONLY_KEYS
} or None
grp_obj, subtasks = self._create_group_object(grp_cfg, parent_name)
return {grp_obj: self._load_subtasks(subtasks, grp_obj, sub_override)}
# ----------------- python-task dict ('class': …) ----------------
if self._config_is_python_task(payload):
name = payload["task"]
return self._create_task_object(payload, name, yaml_path=None)
raise TypeError(
f"_load_individual_task_or_group expected str | dict, got {type(payload)}"
) )
return {
group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
}
def load_task_or_group( def load_task_or_group(
self, task_list: Optional[Union[str, list[str]]] = None self, task_list: Optional[Union[str, list[str]]] = None
...@@ -1363,64 +1255,45 @@ def get_task_dict( ...@@ -1363,64 +1255,45 @@ def get_task_dict(
tm = TaskManager(include_path="/custom/tasks") tm = TaskManager(include_path="/custom/tasks")
tasks = get_task_dict(["custom_task"], task_manager=tm) tasks = get_task_dict(["custom_task"], task_manager=tm)
""" """
from lm_eval.api.task import ConfigurableTask, Task from lm_eval.api.task import Task
task_name_from_string_dict = {}
task_name_from_config_dict = {}
task_name_from_object_dict = {}
# Normalize input to list
if isinstance(task_name_list, str): if isinstance(task_name_list, str):
task_name_list = [task_name_list] task_name_list = [task_name_list]
elif isinstance(task_name_list, list): elif not isinstance(task_name_list, list):
if not all([isinstance(task, (str, dict, Task)) for task in task_name_list]):
raise TypeError( raise TypeError(
"Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match." f"Expected a 'str' or 'list' but received {type(task_name_list)}."
) )
else:
# Validate list items
if not all(isinstance(task, (str, dict, Task)) for task in task_name_list):
raise TypeError( raise TypeError(
f"Expected a 'str' or 'list' but received {type(task_name_list)}." "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
) )
string_task_name_list = [task for task in task_name_list if isinstance(task, str)] # Ensure we have a task manager
others_task_name_list = [
task for task in task_name_list if not isinstance(task, str)
]
if len(string_task_name_list) > 0:
if task_manager is None: if task_manager is None:
task_manager = TaskManager() task_manager = TaskManager()
task_name_from_string_dict = task_manager.load_task_or_group( # Process all items
string_task_name_list final_task_dict = {}
) for task_spec in task_name_list:
if isinstance(task_spec, Task):
for task_element in others_task_name_list: # Pre-instantiated task object
if isinstance(task_element, dict): task_name = get_task_name_from_object(task_spec)
task_name_from_config_dict = { if task_name in final_task_dict:
**task_name_from_config_dict, raise ValueError(f"Duplicate task name: {task_name}")
**task_manager.load_config(config=task_element), final_task_dict[task_name] = task_spec
} else:
# String or dict - use load_task_or_group
elif isinstance(task_element, Task): result = task_manager.load_task_or_group(task_spec)
task_name_from_object_dict = { # Check for duplicate names
**task_name_from_object_dict, for name in result:
get_task_name_from_object(task_element): task_element, if name in final_task_dict:
} raise ValueError(f"Duplicate task name: {name}")
final_task_dict.update(result)
if not set(task_name_from_string_dict.keys()).isdisjoint(
set(task_name_from_object_dict.keys()) # Check for conflicting group memberships
):
raise ValueError
final_task_dict = {
**task_name_from_string_dict,
**task_name_from_config_dict,
**task_name_from_object_dict,
}
# behavior can get odd if one tries to invoke several groups that "compete" for the same task.
# (notably, because one could request several num_fewshot values at once in GroupConfig overrides for the subtask
# and we'd be unsure which to use and report.)
# we explicitly check and error in this case.
_check_duplicates(get_subtask_list(final_task_dict)) _check_duplicates(get_subtask_list(final_task_dict))
return final_task_dict return final_task_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment