Unverified commit a57ffba1, authored by Baber Abbasi, committed by GitHub

Merge pull request #3133 from EleutherAI/tasklist

Add `tasklist`
parents 70314843 bcd6faaa
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
 from inspect import getsource
-from typing import Any, Callable, Optional, Union
+from typing import Callable, Optional, Union

 @dataclass
@@ -22,10 +22,10 @@ class AggMetricConfig(dict):

 @dataclass
-class GroupConfig(dict):
+class GroupConfig:
     group: Optional[str] = None
     group_alias: Optional[str] = None
-    task: Optional[Union[str, list]] = None
+    task: Union[str, list] = field(default_factory=list)
     aggregate_metric_list: Optional[
         Union[list[AggMetricConfig], AggMetricConfig, dict]
     ] = None
@@ -40,6 +40,24 @@ class GroupConfig(dict):
     def __setitem__(self, item, value):
         return setattr(self, item, value)

+    def __contains__(self, item):
+        """Support 'in' operator for dict-like behavior."""
+        return hasattr(self, item)
+
+    def get(self, key, default=None):
+        """Dict-like get method."""
+        return getattr(self, key, default)
+
+    def __hash__(self):
+        """Make GroupConfig hashable based on group name."""
+        return hash(self.group)
+
+    def __eq__(self, other):
+        """Equality comparison based on group name."""
+        if not isinstance(other, GroupConfig):
+            return False
+        return self.group == other.group
+
     def __post_init__(self):
         if self.aggregate_metric_list is not None:
             if isinstance(self.aggregate_metric_list, dict):
@@ -88,33 +106,5 @@ class GroupConfig(dict):
         except (TypeError, OSError):
             return str(value)

-class ConfigurableGroup:
-    def __init__(
-        self,
-        config: Optional[dict] = None,
-    ) -> None:
-        self._config = GroupConfig(**config)
-
-    @property
-    def group(self):
-        return self._config.group
-
-    @property
-    def group_alias(self):
-        return self._config.group_alias
-
-    @property
-    def version(self):
-        return self._config.version
-
-    @property
-    def config(self):
-        return self._config.to_dict()
-
-    @property
-    def group_name(self) -> Any:
-        return self._config.group
-
     def __repr__(self):
-        return f"ConfigurableGroup(group={self.group},group_alias={self.group_alias})"
+        return f"GroupConfig(group={self.group},group_alias={self.group_alias})"
@@ -5,7 +5,7 @@ import ast
 import logging
 import random
 import re
-from collections.abc import Callable
+from collections.abc import Callable, Iterable, Iterator, Mapping
 from copy import deepcopy
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Literal, overload
@@ -376,7 +376,6 @@ class Task(abc.ABC):
         The number of times each instance in a dataset is inferred on. Defaults to 1,
         can be increased for techniques like majority voting.
         """
-        pass

     @abc.abstractmethod
     def process_results(self, doc: dict, results: list) -> dict[str, Any]:
@@ -1249,7 +1248,7 @@ class ConfigurableTask(Task):
         ):  # TODO: ensure that non-multimodal tasks aren't getting visual args
             multimodal_arg = {
                 **multimodal_arg,
-                **{"visual": self.doc_to_image(doc)},
+                "visual": self.doc_to_image(doc),
             }
         if (
@@ -1257,7 +1256,7 @@ class ConfigurableTask(Task):
         ):  # TODO: ensure that non-multimodal tasks aren't getting audio args
             multimodal_arg = {
                 **multimodal_arg,
-                **{"audio": self.doc_to_audio(doc)},
+                "audio": self.doc_to_audio(doc),
             }
         if bool(multimodal_arg):
@@ -1543,6 +1542,8 @@ class MultipleChoiceTask(Task):
         }

     def aggregation(self) -> dict:
+        from lm_eval.api.metrics import mean
+
         return {
             "acc": mean,
             "acc_norm": mean,
@@ -1609,6 +1610,8 @@ class PerplexityTask(Task):
         }

     def aggregation(self) -> dict:
+        from lm_eval.api.metrics import bits_per_byte, weighted_perplexity
+
         return {
             "word_perplexity": weighted_perplexity,
             "byte_perplexity": weighted_perplexity,
...
@@ -340,23 +340,25 @@ class EvaluatorConfig:
             metadata=self.metadata if self.metadata else {},
         )
-        task_names = task_manager.match_tasks(self.tasks)
-
-        # Check for any individual task files in the list
-        for task in [task for task in self.tasks if task not in task_names]:
-            task_path = Path(task)
-            if task_path.is_file():
-                config = utils.load_yaml_config(str(task_path))
-                task_names.append(config)
-
-        # Check for missing tasks
-        task_missing = [
-            task for task in self.tasks if task not in task_names and "*" not in task
-        ]
-
-        if task_missing:
-            missing = ", ".join(task_missing)
-            raise ValueError(f"Tasks not found: {missing}")
+        task_names = self.tasks
+        # TODO: FIX TASKS VALIDATION!!!
+        # task_names = task_manager.match_tasks(self.tasks)
+
+        # # Check for any individual task files in the list
+        # for task in [task for task in self.tasks if task not in task_names]:
+        #     task_path = Path(task)
+        #     if task_path.is_file():
+        #         config = utils.load_yaml_config(str(task_path))
+        #         task_names.append(config)
+        #
+        # # Check for missing tasks
+        # task_missing = [
+        #     task for task in self.tasks if task not in task_names and "*" not in task
+        # ]
+        #
+        # if task_missing:
+        #     missing = ", ".join(task_missing)
+        #     raise ValueError(f"Tasks not found: {missing}")

         # Update tasks with resolved names
         self.tasks = task_names
...
@@ -29,7 +29,8 @@ from lm_eval.evaluator_utils import (
 )
 from lm_eval.loggers import EvaluationTracker
 from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
-from lm_eval.tasks import TaskManager, get_task_dict
+from lm_eval.tasks import TaskManager
+from lm_eval.tasks.manager import get_task_dict
 from lm_eval.utils import (
     get_logger,
     handle_non_serializable,
...
@@ -5,7 +5,6 @@ import pathlib
 import sys
 from typing import List, Optional, Tuple, Union

-from lm_eval.api.group import ConfigurableGroup
 from lm_eval.api.metrics import (
     aggregate_subtask_metrics,
     mean,
@@ -153,11 +152,14 @@ def get_task_list(task_dict: dict) -> List[TaskOutput]:
 def get_subtask_list(task_dict, task_root=None, depth=0):
+    from lm_eval.api.group import GroupConfig
+    from lm_eval.api.task import Task
+
     subtask_list = {}
     for group_obj, task_obj in task_dict.items():
-        if isinstance(group_obj, ConfigurableGroup):
-            # group_name = group_obj.group_name
-            group_name = group_obj.group_name
+        if isinstance(group_obj, GroupConfig):
+            # group_name = group_obj.group
+            group_name = group_obj.group
         else:
             group_name = group_obj
         if isinstance(task_obj, dict):
@@ -175,9 +177,9 @@ def get_subtask_list(task_dict, task_root=None, depth=0):
                 subtask_list = {**subtask_list, **_subtask_list}
         else:
-            if isinstance(task_obj, ConfigurableGroup):
-                # group_or_task_name = task_obj.group_name
-                group_or_task_name = task_obj.group_name
+            if isinstance(task_obj, GroupConfig):
+                # group_or_task_name = task_obj.group
+                group_or_task_name = task_obj.group
             elif isinstance(task_obj, Task):
                 # group_or_task_name = task_obj.task_name
                 group_or_task_name = task_obj.task_name
@@ -224,6 +226,8 @@ def prepare_print_tasks(
     task_depth=0,
     group_depth=0,
 ) -> Tuple[dict, dict]:
+    from lm_eval.api.task import Task
+
     """
     @param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
         value is a list of task names.
@@ -238,6 +242,7 @@ def prepare_print_tasks(
     Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
     """
+    from lm_eval.api.group import GroupConfig

     def _sort_task_dict(task_dict):
         """
@@ -248,8 +253,8 @@ def prepare_print_tasks(
         return dict(
             sorted(
                 task_dict.items(),
-                key=lambda item: item[0].group_name
-                if isinstance(item[0], ConfigurableGroup)
+                key=lambda item: item[0].group
+                if isinstance(item[0], GroupConfig)
                 else item[0],
             )
         )
@@ -259,9 +264,9 @@ def prepare_print_tasks(
     task_dict = _sort_task_dict(task_dict)
     for task_or_group_name, task_or_group_obj in task_dict.items():
         tab_string = " " * task_depth + "- " if task_depth > 0 else ""
-        if isinstance(task_or_group_name, ConfigurableGroup):
-            # string_name = task_or_group_name.group_name
-            name = task_or_group_name.group_name
+        if isinstance(task_or_group_name, GroupConfig):
+            # string_name = task_or_group_name.group
+            name = task_or_group_name.group
             from_configurable_group = True
             task_or_group_obj = _sort_task_dict(task_or_group_obj)
         elif isinstance(task_or_group_name, str):
@@ -395,6 +400,9 @@ def consolidate_group_results(
     The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple.
     In the top-level invocation of this function, task_aggregation_list is ignored.
     """
+    from lm_eval.api.group import GroupConfig
+    from lm_eval.api.task import Task
+
     if task_root is None:
         task_root = {}
@@ -403,9 +411,9 @@ def consolidate_group_results(
     for group_or_task, group_or_task_info in task_dict.items():
         # Convert to string
-        if isinstance(group_or_task, ConfigurableGroup):
-            group_config = group_or_task.config
-            group_or_task = group_or_task.group_name
+        if isinstance(group_or_task, GroupConfig):
+            group_config = group_or_task.to_dict()
+            group_or_task = group_or_task.group
         else:
             group_config = None
@@ -434,7 +442,7 @@ def consolidate_group_results(
             )
             if (group_config is None) or (
-                group_config["aggregate_metric_list"] is None
+                group_config.get("aggregate_metric_list") is None
             ):
                 results[group_or_task][" "] = " "
                 continue
@@ -443,7 +451,7 @@ def consolidate_group_results(
             agg_metric_list = group_config["aggregate_metric_list"]
             show_group_table = show_group_table | bool(
-                group_config["aggregate_metric_list"]
+                group_config.get("aggregate_metric_list")
             )
             task_list = _task_aggregation_list[group_or_task]
...
@@ -3,6 +3,8 @@ import logging
 import os
 from typing import Dict

+import lm_eval.tasks
+import lm_eval.utils
 from lm_eval import utils
@@ -122,7 +124,7 @@ class PromptString:
         if "doc_to_choice" in self.prompt_string:
             raise NotImplementedError("Not yet implemented to accept doc_to_choice")

-        text_string = utils.apply_template(doc_to_text, doc)
-        target_string = utils.apply_template(doc_to_target, doc)
+        text_string = lm_eval.utils.apply_template(doc_to_text, doc)
+        target_string = lm_eval.utils.apply_template(doc_to_target, doc)

         return [text_string, target_string]
-import collections
+from .manager import TaskManager
import inspect
import logging
import os
from functools import partial
from typing import Dict, List, Mapping, Optional, Union
from lm_eval import utils
from lm_eval.api.group import ConfigurableGroup, GroupConfig
from lm_eval.api.task import ConfigurableTask, Task
from lm_eval.evaluator_utils import get_subtask_list
__all__ = ["TaskManager"]
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
eval_logger = logging.getLogger(__name__)
class TaskManager:
"""TaskManager indexes all tasks from the default `lm_eval/tasks/`
and an optional directory if provided.
"""
def __init__(
self,
verbosity: Optional[str] = None,
include_path: Optional[Union[str, List]] = None,
include_defaults: bool = True,
metadata: Optional[dict] = None,
) -> None:
if verbosity is not None:
utils.get_logger(verbosity)
self.include_path = include_path
self.metadata = metadata
self._task_index = self.initialize_tasks(
include_path=include_path, include_defaults=include_defaults
)
self._all_tasks = sorted(list(self._task_index.keys()))
self._all_groups = sorted(
[x for x in self._all_tasks if self._task_index[x]["type"] == "group"]
)
self._all_subtasks = sorted(
[
x
for x in self._all_tasks
if self._task_index[x]["type"] in ["task", "python_task"]
]
)
self._all_tags = sorted(
[x for x in self._all_tasks if self._task_index[x]["type"] == "tag"]
)
self.task_group_map = collections.defaultdict(list)
def initialize_tasks(
self,
include_path: Optional[Union[str, List]] = None,
include_defaults: bool = True,
) -> dict[str, dict]:
"""Creates a dictionary of tasks indexes.
:param include_path: Union[str, List] = None
An additional path to be searched for tasks recursively.
Can provide more than one such path as a list.
:param include_defaults: bool = True
If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
return
Dictionary of task names as key and task metadata
"""
if include_defaults:
all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
else:
all_paths = []
if include_path is not None:
if isinstance(include_path, str):
include_path = [include_path]
all_paths.extend(include_path)
task_index = {}
for task_dir in all_paths:
tasks = self._get_task_and_group(task_dir)
task_index = {**task_index, **tasks}
return task_index
@property
def all_tasks(self):
return self._all_tasks
@property
def all_groups(self):
return self._all_groups
@property
def all_subtasks(self):
return self._all_subtasks
@property
def all_tags(self):
return self._all_tags
@property
def task_index(self):
return self._task_index
def list_all_tasks(
self, list_groups=True, list_tags=True, list_subtasks=True
) -> str:
from pytablewriter import MarkdownTableWriter
def sanitize_path(path):
# don't print full path if we are within the lm_eval/tasks dir !
# if we aren't though, provide the full path.
if "lm_eval/tasks/" in path:
return "lm_eval/tasks/" + path.split("lm_eval/tasks/")[-1]
else:
return path
group_table = MarkdownTableWriter()
group_table.headers = ["Group", "Config Location"]
gt_values = []
for g in self.all_groups:
path = self.task_index[g]["yaml_path"]
if path == -1:
path = "---"
else:
path = sanitize_path(path)
gt_values.append([g, path])
group_table.value_matrix = gt_values
tag_table = MarkdownTableWriter()
tag_table.headers = ["Tag"]
tag_table.value_matrix = [[t] for t in self.all_tags]
subtask_table = MarkdownTableWriter()
subtask_table.headers = ["Task", "Config Location", "Output Type"]
st_values = []
for t in self.all_subtasks:
path = self.task_index[t]["yaml_path"]
output_type = ""
# read the yaml file to determine the output type
if path != -1:
config = utils.load_yaml_config(path, mode="simple")
if "output_type" in config:
output_type = config["output_type"]
elif (
"include" in config
): # if no output type, check if there is an include with an output type
include_path = path.split("/")[:-1] + config["include"]
include_config = utils.load_yaml_config(include_path, mode="simple")
if "output_type" in include_config:
output_type = include_config["output_type"]
if path == -1:
path = "---"
else:
path = sanitize_path(path)
st_values.append([t, path, output_type])
subtask_table.value_matrix = st_values
result = "\n"
if list_groups:
result += group_table.dumps() + "\n\n"
if list_tags:
result += tag_table.dumps() + "\n\n"
if list_subtasks:
result += subtask_table.dumps() + "\n\n"
return result
def match_tasks(self, task_list: list[str]) -> list[str]:
return utils.pattern_match(task_list, self.all_tasks)
def _name_is_registered(self, name: str) -> bool:
if name in self.all_tasks:
return True
return False
def _name_is_task(self, name: str) -> bool:
if self._name_is_registered(name) and (self.task_index[name]["type"] == "task"):
return True
return False
def _name_is_tag(self, name: str) -> bool:
if self._name_is_registered(name) and (self.task_index[name]["type"] == "tag"):
return True
return False
def _name_is_group(self, name: str) -> bool:
if self._name_is_registered(name) and (
self.task_index[name]["type"] == "group"
):
return True
return False
def _name_is_python_task(self, name: str) -> bool:
if self._name_is_registered(name) and (
self.task_index[name]["type"] == "python_task"
):
return True
return False
def _config_is_task(self, config: dict) -> bool:
if ("task" in config) and isinstance(config["task"], str):
return True
return False
def _config_is_group(self, config: dict) -> bool:
if ("task" in config) and isinstance(config["task"], list):
return True
return False
def _config_is_python_task(self, config: dict) -> bool:
if "class" in config:
return True
return False
def _get_yaml_path(self, name: str):
if name not in self.task_index:
raise ValueError
return self.task_index[name]["yaml_path"]
def _get_config(self, name):
if name not in self.task_index:
raise ValueError
yaml_path = self._get_yaml_path(name)
if yaml_path == -1:
return {}
else:
return utils.load_yaml_config(yaml_path, mode="full")
def _get_tasklist(self, name):
if self._name_is_task(name):
raise ValueError
return self.task_index[name]["task"]
def _process_alias(self, config, group=None):
# If the group is not the same as the original
# group which the group alias was intended for,
# Set the group_alias to None instead.
if ("group_alias" in config) and ("group" in config) and group is not None:
if config["group"] != group:
config["group_alias"] = None
return config
def _class_has_config_in_constructor(self, cls):
constructor = getattr(cls, "__init__", None)
return (
"config" in inspect.signature(constructor).parameters
if constructor
else False
)
def _load_individual_task_or_group(
self,
name_or_config: Optional[Union[str, dict]] = None,
parent_name: Optional[str] = None,
update_config: Optional[dict] = None,
) -> Mapping:
def _load_task(config, task):
if "include" in config:
config = {
**utils.load_yaml_config(
yaml_path=None,
yaml_config={"include": config.pop("include")},
mode="full",
),
**config,
}
if self._config_is_python_task(config):
if self._class_has_config_in_constructor(config["class"]):
task_object = config["class"](config=config)
else:
task_object = config["class"]()
if isinstance(task_object, ConfigurableTask):
# very scuffed: set task name here. TODO: fixme?
task_object.config.task = task
else:
if self.metadata is not None:
config["metadata"] = config.get("metadata", {}) | self.metadata
else:
config["metadata"] = config.get("metadata", {})
task_object = ConfigurableTask(config=config)
return {task: task_object}
def _get_group_and_subtask_from_config(
config: dict,
) -> tuple[ConfigurableGroup, list[str]]:
if self.metadata is not None:
config["metadata"] = config.get("metadata", {}) | self.metadata
group_name = ConfigurableGroup(config=config)
subtask_list = []
for task in group_name.config["task"]:
if isinstance(task, str) and self._name_is_tag(task):
subtask_list.extend(self._get_tasklist(task))
else:
subtask_list.append(task)
return group_name, subtask_list
def _process_group_config(
config: dict, update_config: dict = None
) -> tuple[dict, dict]:
if update_config is not None:
config = {**config, **update_config}
_update_config = {
k: v for k, v in config.items() if k not in GROUP_ONLY_KEYS
}
if not bool(_update_config):
_update_config = None
group_config = {k: v for k, v in config.items() if k in GROUP_ONLY_KEYS}
return group_config, _update_config
if isinstance(name_or_config, str):
if update_config is not None:
# Process name_or_config as a dict instead
name_or_config = {"task": name_or_config, **update_config}
elif self._name_is_task(name_or_config) or self._name_is_python_task(
name_or_config
):
task_config = self._get_config(name_or_config)
return _load_task(task_config, task=name_or_config)
else:
subtask_list = self._get_tasklist(name_or_config)
if subtask_list == -1:
group_config = self._get_config(name_or_config)
group_config, update_config = _process_group_config(group_config)
group_name, subtask_list = _get_group_and_subtask_from_config(
group_config
)
else:
if self._name_is_tag(name_or_config):
fn = partial(
self._load_individual_task_or_group,
update_config=name_or_config
if isinstance(name_or_config, dict)
else None,
)
return dict(
collections.ChainMap(*map(fn, reversed(subtask_list)))
)
else:
group_name = ConfigurableGroup(
config={"group": name_or_config, "task": subtask_list}
)
if isinstance(name_or_config, dict):
if self._config_is_task(name_or_config):
name = name_or_config.pop("task")
if update_config is not None:
name_or_config = {**name_or_config, **update_config}
# If the name is registered as a group
if self._name_is_group(name):
group_config = self._get_config(name)
group_config, update_config = _process_group_config(
group_config, name_or_config
)
group_name, subtask_list = _get_group_and_subtask_from_config(
group_config
)
elif self._name_is_tag(name):
subtask_list = self._get_tasklist(name)
fn = partial(
self._load_individual_task_or_group,
update_config=name_or_config,
)
return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
else:
if self._name_is_registered(name):
base_task_config = self._get_config(name)
# Check if this is a duplicate.
if parent_name is not None:
num_duplicate = len(
list(
filter(
lambda x: x.startswith(name),
self.task_group_map[parent_name],
)
)
)
if num_duplicate > 0:
name = f"{name}-{num_duplicate}"
self.task_group_map[parent_name].append(name)
task_config = {
**base_task_config,
**name_or_config,
}
else:
task_config = name_or_config
return _load_task(task_config, task=name)
else:
group_config, update_config = _process_group_config(name_or_config)
group_name, subtask_list = _get_group_and_subtask_from_config(
group_config
)
fn = partial(
self._load_individual_task_or_group,
parent_name=group_name,
update_config=update_config,
)
return {
group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
}
def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict:
"""Loads a dictionary of task objects from a list
:param task_list: Union[str, list] = None
Single string or list of string of task names to be loaded
:return
Dictionary of task objects
"""
if isinstance(task_list, str):
task_list = [task_list]
all_loaded_tasks = dict(
collections.ChainMap(
*map(
lambda task: self._load_individual_task_or_group(task),
task_list,
)
)
)
return all_loaded_tasks
def load_config(self, config: Dict):
return self._load_individual_task_or_group(config)
def _get_task_and_group(self, task_dir: str):
"""Creates a dictionary of tasks index with the following metadata,
- `type`, that can be either `task`, `python_task`, `group` or `tags`.
`task` refer to regular task configs, `python_task` are special
yaml files that only consists of `task` and `class` parameters.
`group` are group configs. `tags` are labels that can be assigned
to tasks to assist in sorting and calling tasks of certain themes.
- `yaml_path`, path to the yaml file. If the entry is a `group` that
was configured through a task config, the yaml_path will be -1
and all subtasks will be listed in `task` (see below)
- `task`, reserved for entries with `type` as `group`. This will list
all subtasks. When a group config is created (as opposed to task
config having `group` parameter set), this will be set to -1 to
avoid recursive indexing. The whole list of subtasks will be loaded
at evaluation.
:param task_dir: str
A directory to check for tasks
:return
Dictionary of task names as key and task metadata
"""
def _populate_tags_and_groups(config, task, tasks_and_groups, print_info):
# TODO: remove group in next release
if "tag" in config:
attr_list = config["tag"]
if isinstance(attr_list, str):
attr_list = [attr_list]
for tag in attr_list:
if tag not in tasks_and_groups:
tasks_and_groups[tag] = {
"type": "tag",
"task": [task],
"yaml_path": -1,
}
elif tasks_and_groups[tag]["type"] != "tag":
eval_logger.info(
f"The tag '{tag}' is already registered as a group, this tag will not be registered. "
"This may affect tasks you want to call."
)
break
else:
tasks_and_groups[tag]["task"].append(task)
# TODO: remove group in next release
print_info = True
ignore_dirs = [
"__pycache__",
".ipynb_checkpoints",
]
tasks_and_groups = collections.defaultdict()
for root, dirs, file_list in os.walk(task_dir):
dirs[:] = [d for d in dirs if d not in ignore_dirs]
for f in file_list:
if f.endswith(".yaml"):
yaml_path = os.path.join(root, f)
config = utils.load_yaml_config(yaml_path, mode="simple")
if self._config_is_python_task(config):
# This is a python class config
task = config["task"]
tasks_and_groups[task] = {
"type": "python_task",
"yaml_path": yaml_path,
}
_populate_tags_and_groups(
config, task, tasks_and_groups, print_info
)
elif self._config_is_group(config):
# This is a group config
tasks_and_groups[config["group"]] = {
"type": "group",
"task": -1, # This signals that
# we don't need to know
# the task list for indexing
# as it can be loaded
# when called.
"yaml_path": yaml_path,
}
# # Registered the level 1 tasks from a group config
# for config in config["task"]:
# if isinstance(config, dict) and self._config_is_task(config):
# task = config["task"]
# tasks_and_groups[task] = {
# "type": "task",
# "yaml_path": yaml_path,
# }
elif self._config_is_task(config):
# This is a task config
task = config["task"]
tasks_and_groups[task] = {
"type": "task",
"yaml_path": yaml_path,
}
_populate_tags_and_groups(
config, task, tasks_and_groups, print_info
)
else:
eval_logger.debug(f"File {f} in {root} could not be loaded")
return tasks_and_groups
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
if "task" in task_config:
return task_config["task"]
if "dataset_name" in task_config:
return "{dataset_path}_{dataset_name}".format(**task_config)
else:
return "{dataset_path}".format(**task_config)
def get_task_name_from_object(task_object):
if hasattr(task_object, "config"):
return task_object._config["task"]
# TODO: scrap this
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return (
task_object.EVAL_HARNESS_NAME
if hasattr(task_object, "EVAL_HARNESS_NAME")
else type(task_object).__name__
)
def _check_duplicates(task_dict: dict) -> None:
"""helper function solely used in validating get_task_dict output.
Takes the output of lm_eval.evaluator_utils.get_subtask_list and
returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are
"oversubscribed" to several disjoint groups.
"""
subtask_names = []
for key, value in task_dict.items():
subtask_names.extend(value)
duplicate_tasks = {
task_name for task_name in subtask_names if subtask_names.count(task_name) > 1
}
# locate the potentially problematic groups that seem to 'compete' for constituent subtasks
competing_groups = [
group
for group in task_dict.keys()
if len(set(task_dict[group]).intersection(duplicate_tasks)) > 0
]
if len(duplicate_tasks) > 0:
raise ValueError(
f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {list(duplicate_tasks)}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs."
)
def get_task_dict(
task_name_list: Union[str, List[Union[str, Dict, Task]]],
task_manager: Optional[TaskManager] = None,
):
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
:param task_name_list: List[Union[str, Dict, Task]]
Name of model or LM object, see lm_eval.models.get_model
:param task_manager: TaskManager = None
A TaskManager object that stores indexed tasks. If not set,
task_manager will load one. This should be set by the user
if there are additional paths that want to be included
via `include_path`
:return
Dictionary of task objects
"""
task_name_from_string_dict = {}
task_name_from_config_dict = {}
task_name_from_object_dict = {}
if isinstance(task_name_list, str):
task_name_list = [task_name_list]
elif isinstance(task_name_list, list):
if not all([isinstance(task, (str, dict, Task)) for task in task_name_list]):
raise TypeError(
"Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
)
else:
raise TypeError(
f"Expected a 'str' or 'list' but received {type(task_name_list)}."
)
string_task_name_list = [task for task in task_name_list if isinstance(task, str)]
others_task_name_list = [
task for task in task_name_list if not isinstance(task, str)
]
if len(string_task_name_list) > 0:
if task_manager is None:
task_manager = TaskManager()
task_name_from_string_dict = task_manager.load_task_or_group(
string_task_name_list
)
for task_element in others_task_name_list:
if isinstance(task_element, dict):
task_name_from_config_dict = {
**task_name_from_config_dict,
**task_manager.load_config(config=task_element),
}
elif isinstance(task_element, Task):
task_name_from_object_dict = {
**task_name_from_object_dict,
get_task_name_from_object(task_element): task_element,
}
if not set(task_name_from_string_dict.keys()).isdisjoint(
set(task_name_from_object_dict.keys())
):
raise ValueError
final_task_dict = {
**task_name_from_string_dict,
**task_name_from_config_dict,
**task_name_from_object_dict,
}
# behavior can get odd if one tries to invoke several groups that "compete" for the same task.
# (notably, because one could request several num_fewshot values at once in GroupConfig overrides for the subtask
# and we'd be unsure which to use and report.)
# we explicitly check and error in this case.
_check_duplicates(get_subtask_list(final_task_dict))
return final_task_dict
from __future__ import annotations
import importlib.util
import sys
from pathlib import Path
from typing import Any
import yaml
_Base = (
yaml.CSafeLoader if getattr(yaml, "__with_libyaml__", False) else yaml.FullLoader
)
_IGNORE_DIRS = {"__pycache__", ".ipynb_checkpoints"}
def _mk_function_ctor(base_dir: Path, resolve: bool):
def ctor(loader: yaml.Loader, node: yaml.Node):
spec = loader.construct_scalar(node) # type: ignore[arg-type]
if not resolve:
return str(base_dir.expanduser() / spec)
return _import_func_in_yml(spec, base_dir)
return ctor
def _make_loader(base_dir: Path, *, resolve_funcs: bool) -> type[yaml.Loader]:
class Loader(_Base): ... # type: ignore[no-redef]
yaml.add_constructor(
"!function",
_mk_function_ctor(base_dir, resolve_funcs),
Loader=Loader,
)
return Loader
def _load_module_with_cache(module_path: Path) -> Any:
"""Load a module from a file path with caching and hot-reload support.
Args:
module_path: Path to the Python file to load
Returns:
The loaded module
"""
# Determine module name based on location
path_str = str(module_path)
# Check if this is a built-in task module
if "/lm_eval/tasks/" in path_str:
# Find the position of lm_eval/tasks/ in the path
tasks_idx = path_str.find("/lm_eval/tasks/")
if tasks_idx != -1:
# Extract path starting from lm_eval/tasks/
# e.g., /path/to/lm_eval/tasks/hellaswag/utils.py → hellaswag/utils.py
relative_path = path_str[tasks_idx + len("/lm_eval/tasks/") :]
# Remove .py and convert to module name
# e.g., hellaswag/utils.py → lm_eval.tasks.hellaswag.utils
module_parts = relative_path.replace(".py", "").replace("/", ".")
module_name = f"lm_eval.tasks.{module_parts}"
else:
# Fallback to full path if pattern not found
module_name = str(module_path.with_suffix(""))
else:
# External module - use full path without extension
module_name = str(module_path.with_suffix(""))
# Check if we need to reload the module
if module_name in sys.modules:
existing_module = sys.modules[module_name]
# Check if it was modified
current_mtime = module_path.stat().st_mtime_ns
if (
hasattr(existing_module, "__mtime__")
and existing_module.__mtime__ == current_mtime
):
# Module hasn't changed, reuse it
return existing_module
# Load or reload the module
spec = importlib.util.spec_from_file_location(module_name, module_path)
if spec is None or spec.loader is None:
raise ImportError(f"Cannot load module from {module_path}") from None
module = importlib.util.module_from_spec(spec)
# Store mtime for future checks
module.__mtime__ = module_path.stat().st_mtime_ns
spec.loader.exec_module(module) # type: ignore[arg-type]
sys.modules[module_name] = module
return module
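A small usage sketch of the mtime-based cache above (paths hypothetical): reloading an unmodified file returns the cached module object, while touching the file forces a re-exec on the next load.

    from pathlib import Path

    p = Path("/tmp/custom_tasks/utils.py")   # hypothetical helper module
    m1 = _load_module_with_cache(p)
    m2 = _load_module_with_cache(p)
    assert m1 is m2                  # unchanged mtime -> cached module reused
    p.touch()                        # mtime changes
    m3 = _load_module_with_cache(p)  # module is re-executed and re-cached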
def _import_func_in_yml(qual: str, base_dir: Path):
"""Import function from qual: utils.process_doc, checking local files first then standard imports.
Args:
qual: Qualified function name (e.g., 'utils.process_doc')
base_dir: Directory to search for local modules
"""
mod_path, _, fn_name = qual.rpartition(".")
# 1) relative "utils.py" next to YAML
rel = (base_dir / f"{mod_path.replace('.', '/')}.py").resolve()
if rel.exists():
module = _load_module_with_cache(rel)
return getattr(module, fn_name)
# 2) already-importable module
module = __import__(mod_path, fromlist=[fn_name])
return getattr(module, fn_name)
def _import_fun_from_str(path_str: str) -> Any:
"""Import a function from a string in the form '/absolute/path/to/module.function_name'."""
try:
# Split off the function name from the rightmost dot
module_path_str, function_name = path_str.rsplit(".", 1)
except ValueError as e:
raise ValueError(
f"Invalid path format: {path_str}. Expected format: /path/to/module.function_name"
) from e
# Convert to Path and handle .py extension
module_path = Path(module_path_str)
if not module_path.suffix:
module_path = module_path.with_suffix(".py")
elif module_path.suffix != ".py":
# If it has a non-.py suffix, the user might have included .py in the path
# e.g., "/path/to/module.py.function_name"
base_path = module_path.with_suffix("")
if base_path.with_suffix(".py").exists():
module_path = base_path.with_suffix(".py")
if not module_path.exists():
raise ImportError(f"Module file not found: {module_path}")
module = _load_module_with_cache(module_path)
if not hasattr(module, function_name):
raise AttributeError(
f"Function '{function_name}' not found in module {module_path}"
)
return getattr(module, function_name)
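For instance (hypothetical module path), both spellings below should resolve to the same callable, since a stray `.py` embedded in the dotted string is normalized away:

    fn_a = _import_fun_from_str("/opt/tasks/metrics.exact_match")
    fn_b = _import_fun_from_str("/opt/tasks/metrics.py.exact_match")
    assert fn_a is fn_b  # same file via _load_module_with_cache, same function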
def load_yaml(
path: str | Path,
*,
resolve_func: bool = True,
recursive: bool = True,
_seen: set[Path] | None = None,
) -> dict[str, Any]:
"""Pure data-loading helper.
Returns a dict ready for higher-level interpretation.
    • No task/group/tag semantics here.
"""
path = Path(path).expanduser().resolve()
if _seen is None:
_seen = set()
if path in _seen:
raise ValueError(f"Include cycle at {path}")
_seen.add(path)
loader_cls = _make_loader(path.parent, resolve_funcs=resolve_func)
with path.open("rb") as fh:
cfg = yaml.load(fh, Loader=loader_cls)
if not recursive or "include" not in cfg:
return cfg
else:
includes = cfg.pop("include")
merged = {}
for inc in includes if isinstance(includes, list) else [includes]:
inc_path = (path.parent / inc) if not Path(inc).is_absolute() else Path(inc)
merged.update(
load_yaml(
inc_path,
resolve_func=resolve_func,
recursive=True,
_seen=_seen,
),
)
merged.update(cfg) # local keys win
return merged
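A minimal sketch of the include-merge semantics (file contents hypothetical): keys in the including file win over its includes, later includes override earlier ones, and `!function` tags resolve to callables only when resolve_func=True.

    # base.yaml:
    #   output_type: loglikelihood
    #   num_fewshot: 0
    # task.yaml:
    #   include: base.yaml
    #   num_fewshot: 5
    #   process_docs: !function utils.process_docs
    cfg = load_yaml("task.yaml")
    assert cfg["num_fewshot"] == 5                # local key wins over include
    assert cfg["output_type"] == "loglikelihood"  # inherited from base.yaml
    assert callable(cfg["process_docs"])          # resolved from utils.py next to the YAML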
from __future__ import annotations
import inspect
from collections.abc import Mapping
from copy import deepcopy
from typing import Any
from lm_eval.api.group import GroupConfig
from lm_eval.api.task import ConfigurableTask, Task # noqa: F401 (typing)
from lm_eval.tasks._config_loader import load_yaml as load_cfg
from lm_eval.tasks.index import Entry, Kind
load_cfg_cached = load_cfg # type: ignore[no-redef]
class TaskFactory:
"""
Turns a *Entry* (plus optional overrides) into a
*Task* | *ConfigurableTask* | *GroupConfig* hierarchy.
"""
def __init__(self, *, meta: dict[str, Any] | None = None):
self._meta = meta or {}
# ---------------------------------------------------------------- public API
def build(
self,
entry: Entry,
*,
overrides: dict[str, Any] | None = None,
registry: Mapping[str, Entry],
):
"""
• entry.kind == TASK / PY_TASK ➜ returns instantiated task object
• entry.kind == GROUP ➜ returns (GroupConfig, mapping-of-subtasks)
• entry.kind == TAG ➜ returns mapping-of-tasks (tag expansion)
"""
if entry.kind is Kind.TAG:
return self._build_tag(entry, overrides, registry)
if entry.kind is Kind.GROUP:
return self._build_group(entry, overrides, registry)
return self._build_task(entry, overrides)
def _build_task(self, entry: Entry, overrides: dict[str, Any] | None):
cfg = self._load_full_config(entry, overrides)
if "class" in cfg: # PY_TASK route
cls = cfg["class"]
obj = cls(config=cfg) if _ctor_accepts_config(cls) else cls()
if isinstance(obj, ConfigurableTask):
obj.config.task = entry.name
return obj
# YAML task
return ConfigurableTask(config=cfg) # type: ignore[arg-type]
def _build_group(
self,
entry: Entry,
overrides: dict[str, Any] | None,
registry: Mapping[str, Entry],
):
raw_cfg = self._load_full_config(entry, None)
grp_cfg = {k: v for k, v in raw_cfg.items() if k in GroupConfig.__annotations__}
grp_cfg["metadata"] = grp_cfg.get("metadata", {}) | self._meta
group_obj = GroupConfig(**grp_cfg)
children: dict[str, Any] = {}
for item in group_obj.task:
if isinstance(item, str): # task: hellaswag
child = self.build(
registry[item],
overrides=overrides, # group-level overrides propagate
registry=registry,
)
elif isinstance(item, dict): # task: {task: hellaswag, num_fewshot: 5}
base_name = item["task"]
child = self.build(
registry[base_name],
overrides=item, # per-item override
registry=registry,
)
else:
raise TypeError(
f"Unsupported sub-entry {item!r} in group '{entry.name}'"
)
# `child` itself is a mapping (task-name -> obj) or {GroupConfig: ...}
children.update(child)
return {group_obj: children}
def _build_tag(
self,
entry: Entry,
overrides: dict[str, Any] | None,
registry: Mapping[str, Entry],
):
return {
name: self._build_task(registry[name], overrides) for name in entry.tags
}
def _load_full_config(
self, entry: Entry, overrides: dict[str, Any] | None
) -> dict[str, Any]:
if entry.yaml_path:
cfg = deepcopy(load_cfg_cached(entry.yaml_path, resolve_func=True))
else:
cfg = {"metadata": {"config": "unknown"}} # python task without YAML
if overrides:
cfg = {**cfg, **overrides}
cfg["metadata"] = (
m if isinstance(m := cfg.get("metadata", {}), dict) else {"_metadata": m}
) | self._meta
cfg.setdefault("task", entry.name)
return cfg
def _ctor_accepts_config(cls) -> bool:
    init = getattr(cls, "__init__", None)
    # bool() so a missing __init__ yields False rather than None
    return bool(init) and "config" in inspect.signature(init).parameters
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import TYPE_CHECKING, Any
from lm_eval.tasks._config_loader import load_yaml as load_cfg
if TYPE_CHECKING:
from collections.abc import Iterable
from pathlib import Path
class Kind(Enum):
TASK = auto() # YAML task, or task_list entry
PY_TASK = auto() # Python-defined, via "class"
GROUP = auto()
TAG = auto()
TASK_LIST = auto()
@dataclass
class Entry:
name: str
kind: Kind
yaml_path: Path | None # None for generated / py-only entries
cfg: dict[str, str] | None = None
tags: set[str] = field(default_factory=set)
task_list_path: Path | None = None
log = logging.getLogger(__name__)
_IGNORE_DIRS = {"__pycache__", ".ipynb_checkpoints"}
class TaskIndex:
"""Walks one or more directories, parses YAML quickly (functions unresolved),
and produces a mapping {task_name: Entry}.
"""
def __init__(self, *, meta: dict[str, str] | None = None) -> None:
self._metadata = meta or {}
def build(
self,
paths: Iterable[Path],
*,
resolve_includes=False,
) -> dict[str, Entry]:
index: dict[str, Entry] = {}
log.debug("Building task index from %s", paths)
for root in paths:
for yaml_path in self._iter_yaml_files(root):
try:
cfg = load_cfg(
yaml_path,
resolve_func=False,
recursive=resolve_includes,
)
self.process_cfg(cfg, yaml_path, index)
except Exception as err:
log.debug("Skip %s (%s)", yaml_path, err)
continue
# self._process_cfg(cfg, yaml_path, index)
log.debug("Built task index with %d entries", len(index))
return index
@staticmethod
def _iter_yaml_files(root: Path):
yield from (
p
for p in root.glob("**/*.yaml")
if not any(part in _IGNORE_DIRS for part in p.parts)
)
@staticmethod
def process_cfg(
cfg: dict[str, Any],
path: Path,
index: dict[str, Entry],
) -> None:
kind = TaskIndex._kind_of(cfg)
if kind is Kind.GROUP:
grp_name = cfg["group"]
index[grp_name] = Entry(
name=grp_name,
kind=Kind.GROUP,
yaml_path=path,
tags=TaskIndex._str_to_set(cfg.get("tag")),
cfg=cfg,
)
return
if kind is Kind.PY_TASK:
name = cfg["task"]
index[name] = Entry(
name=name,
kind=Kind.PY_TASK,
yaml_path=None,
tags=TaskIndex._str_to_set(cfg.get("tag")),
cfg=cfg,
)
TaskIndex._register_tags(name, cfg.get("tag"), index)
return
if kind is Kind.TASK:
name = cfg["task"]
index[name] = Entry(
name=name,
kind=Kind.TASK,
yaml_path=path,
tags=TaskIndex._str_to_set(cfg.get("tag")),
cfg=cfg,
)
TaskIndex._register_tags(name, cfg.get("tag"), index)
return
if kind is Kind.TASK_LIST:
for entry in cfg["task_list"]:
task_name = entry["task"] if isinstance(entry, dict) else entry
index[task_name] = Entry(
name=task_name,
kind=Kind.TASK,
yaml_path=path,
tags=TaskIndex._str_to_set(cfg.get("tag")),
cfg=cfg,
)
TaskIndex._register_tags(task_name, entry.get("tag"), index)
return
@staticmethod
def _register_tags(
task: str,
tags: str | list[str] | None,
index: dict[str, Entry],
) -> None:
if not tags:
return
for tag in tags if isinstance(tags, list) else [tags]:
entry = index.setdefault(
tag,
Entry(name=tag, kind=Kind.TAG, yaml_path=None, tags=set()),
)
entry.tags.add(task)
@staticmethod
def _kind_of(cfg: dict) -> Kind:
if "class" in cfg:
return Kind.PY_TASK
if "group" in cfg:
return Kind.GROUP
if "task_list" in cfg:
return Kind.TASK_LIST
if "task" in cfg:
return Kind.GROUP if isinstance(cfg["task"], list) else Kind.TASK
msg = "Unknown config shape"
raise ValueError(msg) from None
@staticmethod
def _str_to_set(tags: str | list[str] | None = None) -> set[str]:
"""Convert a string or list of strings to a set of strings."""
return (
set(tags)
if isinstance(tags, list)
else {tags}
if isinstance(tags, str)
else set()
)
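As a quick reference, `_kind_of` classifies config shapes in this precedence order (`class` > `group` > `task_list` > `task`); the configs below are illustrative:

    assert TaskIndex._kind_of({"class": object}) is Kind.PY_TASK
    assert TaskIndex._kind_of({"group": "mmlu", "task": ["mmlu_stem"]}) is Kind.GROUP
    assert TaskIndex._kind_of({"task_list": ["boolq", "piqa"]}) is Kind.TASK_LIST
    assert TaskIndex._kind_of({"task": "hellaswag"}) is Kind.TASK
    assert TaskIndex._kind_of({"task": ["boolq", "piqa"]}) is Kind.GROUP  # bare list => group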
from __future__ import annotations
from collections import defaultdict
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Any
from lm_eval.tasks.factory import TaskFactory
from lm_eval.tasks.index import Entry, Kind, TaskIndex
from lm_eval.utils import setup_logging
if TYPE_CHECKING:
from lm_eval.api.task import Task
class TaskManager:
def __init__(
self,
verbosity: str | None = None,
include_path: str | Path | list[str | Path] | None = None,
include_defaults: bool = True,
metadata: dict[str, dict[str, Any]] | None = None,
) -> None:
if verbosity:
setup_logging(verbosity)
index = TaskIndex()
self._factory = TaskFactory(meta=metadata)
all_paths: list[Path] = []
if include_defaults:
all_paths.append(Path(__file__).parent)
if include_path:
all_paths += [
Path(p)
for p in (
include_path
if isinstance(include_path, (list, tuple))
else [include_path]
)
]
self._index = index.build(all_paths)
buckets = defaultdict(list)
for k, e in self._index.items():
buckets[e.kind].append(k)
self._all_tasks = sorted(
chain.from_iterable(buckets[k] for k in {Kind.TASK, Kind.PY_TASK})
)
self._all_groups = sorted(buckets[Kind.GROUP])
self._all_tags = sorted(buckets[Kind.TAG])
def _entry(self, name: str) -> Entry:
if name not in self._index:
raise KeyError(f"Unknown task/group/tag: {name}")
return self._index[name]
def load_spec(self, spec: str | dict[str, Any]):
"""Spec can be:
• str task / group / tag name (registered)
• dict inline overrides {'task': 'hellaswag', 'num_fewshot': 5}
"""
if isinstance(spec, str):
entry = self._entry(spec)
return self._factory.build(entry, overrides=None, registry=self._index)
if isinstance(spec, dict):
# inline dict => find base entry, then pass overrides
name = spec["task"]
entry = self._entry(name)
return self._factory.build(entry, overrides=spec, registry=self._index)
raise TypeError("spec must be str or dict")
def load_task_or_group(self, task_list: str | list[str]):
return (
[self.load_spec(s) for s in task_list]
if isinstance(task_list, list)
else [self.load_spec(task_list)]
)
def get_task_dict(
    task_name_list: str | list[str | dict | Task],
    task_manager: TaskManager | None = None,
):
    if isinstance(task_name_list, str):
        # a bare string would otherwise be iterated character by character below
        task_name_list = [task_name_list]
    if not task_manager:
        task_manager = TaskManager()
    else:
        assert isinstance(task_manager, TaskManager)
return {
task_name: task_manager.load_spec(task_name)
if isinstance(task_name, str)
else task_name
for task_name in task_name_list
}
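Taken together, the new manager replaces the old monolithic TaskManager from lm_eval/tasks/__init__.py; a minimal usage sketch (task names and include path illustrative):

    from lm_eval.tasks.manager import TaskManager, get_task_dict

    tm = TaskManager(include_path="/my/custom/tasks")  # hypothetical extra task dir
    task_dict = get_task_dict(["hellaswag"], task_manager=tm)
    # inline overrides route through load_spec's dict branch:
    overridden = tm.load_spec({"task": "hellaswag", "num_fewshot": 5})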
@@ -15,11 +15,10 @@ from dataclasses import asdict, is_dataclass
 from functools import lru_cache, partial, wraps
 from itertools import islice
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional

 import numpy as np
-import yaml
-from jinja2 import BaseLoader, Environment, StrictUndefined, Template
+from jinja2 import BaseLoader, Environment, StrictUndefined

 SPACING = " " * 47
@@ -117,8 +116,7 @@ def setup_logging(verbosity=logging.INFO, suppress_third_party=True):
     # Configure custom formatter
     class CustomFormatter(logging.Formatter):
         def format(self, record):
-            if record.name.startswith("lm_eval."):
-                record.name = record.name[len("lm_eval.") :]
+            record.name = record.name.removeprefix("lm_eval.")
             return super().format(record)

     formatter = CustomFormatter(
@@ -527,105 +525,6 @@ def positional_deprecated(fn):
     return _wrapper
def ignore_constructor(loader, node):
return node
def import_function(loader: yaml.Loader, node, yaml_path: Path):
function_name = loader.construct_scalar(node)
*module_name, function_name = function_name.split(".")
if isinstance(module_name, list):
module_name = ".".join(module_name)
module_path = yaml_path.parent / f"{module_name}.py"
spec = importlib.util.spec_from_file_location(module_name, module_path.as_posix())
if spec is None:
raise ImportError(f"Could not import module {module_name} from {module_path}.")
module = importlib.util.module_from_spec(spec)
if spec.loader is None:
raise ImportError(f"Module loader is None, {module_name} from {module_path}.")
spec.loader.exec_module(module)
function = getattr(module, function_name)
return function
def load_yaml_config(
yaml_path: str | None = None, yaml_config=None, yaml_dir=None, mode="full"
):
if mode == "simple":
constructor_fn = ignore_constructor
elif mode == "full":
if yaml_path is None:
raise ValueError("yaml_path must be provided if mode is 'full'.")
# Attach yaml_path to the import function so that it can be used later
constructor_fn = partial(import_function, yaml_path=Path(yaml_path))
loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
# Add the import_function constructor to the YAML loader
yaml.add_constructor("!function", constructor_fn, Loader=loader)
if yaml_config is None:
with open(yaml_path, "rb") as file:
yaml_config = yaml.load(file, Loader=loader)
if yaml_dir is None:
yaml_dir = os.path.dirname(yaml_path)
assert yaml_dir is not None
if "include" in yaml_config:
include_path = yaml_config["include"]
del yaml_config["include"]
if isinstance(include_path, str):
include_path = [include_path]
# Load from the last one first
include_path.reverse()
final_yaml_config = {}
for path in include_path:
# Assumes that path is a full path.
# If not found, assume the included yaml
# is in the same dir as the original yaml
if not os.path.isfile(path):
path = os.path.join(yaml_dir, path)
try:
included_yaml_config = load_yaml_config(yaml_path=path, mode=mode)
final_yaml_config.update(included_yaml_config)
except Exception as ex:
# If failed to load, ignore
raise ex
final_yaml_config.update(yaml_config)
return final_yaml_config
return yaml_config
def regex_replace(string, pattern, repl, count: int = 0):
"""Implements the `re.sub` function as a custom Jinja filter."""
return re.sub(pattern, repl, string, count=count)
env = Environment(
loader=BaseLoader(), undefined=StrictUndefined, keep_trailing_newline=True
)
env.filters["regex_replace"] = regex_replace
@lru_cache(maxsize=128)
def _compile(raw: str) -> Template:
return env.from_string(raw)
def apply_template(template: str, doc: dict) -> str:
rtemplate = _compile(template)
return rtemplate.render(**doc)
 def create_iterator(
     raw_iterator: collections.Iterator,
     *,
@@ -705,3 +604,25 @@ def hash_dict_images(data_dict):
         if importlib.util.find_spec("PIL")
         else data_dict
     )
def regex_replace(string, pattern, repl, count: int = 0):
"""Implements the `re.sub` function as a custom Jinja filter."""
return re.sub(pattern, repl, string, count=count)
@functools.lru_cache(maxsize=256)
def _compile_tpl(src: str):
return apply_template._env.from_string(src)
def apply_template(template: str, doc: dict) -> str:
if not hasattr(apply_template, "_env"):
apply_template._env = Environment(
loader=BaseLoader(),
undefined=StrictUndefined,
keep_trailing_newline=True,
)
apply_template._env.filters["regex_replace"] = regex_replace
return _compile_tpl(template).render(**doc)
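The rewritten apply_template builds a single Jinja environment lazily and caches compiled templates via lru_cache; for example (doc contents illustrative):

    doc = {"question": "2 + 2 = ?", "answer": "4"}
    prompt = apply_template("Q: {{question}}\nA:", doc)
    # the custom filter registered above is available inside templates:
    spelled = apply_template("{{ answer | regex_replace('4', 'four') }}", doc)
    assert spelled == "four"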
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
 name = "lm_eval"
 version = "0.4.9.1"
 authors = [
-    {name="EleutherAI", email="contact@eleuther.ai"}
+    { name = "EleutherAI", email = "contact@eleuther.ai" }
 ]
 description = "A framework for evaluating language models"
 readme = "README.md"
@@ -19,25 +19,22 @@ classifiers = [
 requires-python = ">=3.9"
 license = { "text" = "MIT" }
 dependencies = [
     "accelerate>=0.26.0",
     "datasets>=2.16.0,<4.0",
     "evaluate>=0.4.0",
     "peft>=0.2.0",
     "pytablewriter",
     "rouge-score>=0.0.4",
     "sacrebleu>=1.5.0",
     "scikit-learn>=0.24.1",
     "sqlitedict",
     "torch>=1.8",
     "transformers>=4.1",
     "dill",
     "word2number",
     "more_itertools"
 ]

-[tool.setuptools.packages.find]
-include = ["lm_eval*"]
-
 # required to include yaml files in pip installation
 [tool.setuptools.package-data]
 lm_eval = ["**/*.yaml", "tasks/**/*"]
@@ -63,7 +60,7 @@ ifeval = ["langdetect", "immutabledict", "nltk>=3.9.1"]
 ipex = ["optimum"]
 japanese_leaderboard = ["emoji==2.14.0", "neologdn==0.5.3", "fugashi[unidic-lite]", "rouge_score>=0.1.2"]
 longbench = ["jieba", "fuzzywuzzy", "rouge"]
-libra=["pymorphy2"]
+libra = ["pymorphy2"]
 mamba = ["mamba_ssm", "causal-conv1d==1.0.2", "torch"]
 math = ["sympy>=1.12", "antlr4-python3-runtime==4.11", "math_verify[antlr4_11_0]"]
 multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"]
@@ -75,22 +72,17 @@ sae_lens = ["sae_lens"]
 sentencepiece = ["sentencepiece>=0.1.98"]
 sparsify = ["sparsify"]
 discrim_eval = ["statsmodels==0.14.4"]
-testing = ["pytest", "pytest-cov", "pytest-xdist"]
-unitxt = ["unitxt==1.22.0"]
-vllm = ["vllm>=0.4.2"]
-wandb = ["wandb>=0.16.3", "pandas", "numpy"]
-zeno = ["pandas", "zeno-client"]
 tasks = [
     "lm_eval[acpbench]",
     "lm_eval[discrim_eval]",
     "lm_eval[ifeval]",
     "lm_eval[japanese_leaderboard]",
     "lm_eval[longbench]",
     "lm_eval[libra]",
     "lm_eval[mamba]",
     "lm_eval[math]",
     "lm_eval[multilingual]",
     "lm_eval[ruler]"
 ]
 testing = ["pytest", "pytest-cov", "pytest-xdist"]
 unitxt = ["unitxt==1.22.0"]
@@ -98,14 +90,6 @@ vllm = ["vllm>=0.4.2"]
 wandb = ["wandb>=0.16.3", "pandas", "numpy"]
 zeno = ["pandas", "zeno-client"]

-[project.scripts]
-lm-eval = "lm_eval.__main__:cli_evaluate"
-lm_eval = "lm_eval.__main__:cli_evaluate"
-
-[project.urls]
-Homepage = "https://github.com/EleutherAI/lm-evaluation-harness"
-Repository = "https://github.com/EleutherAI/lm-evaluation-harness"
-
 [tool.pymarkdown]
 plugins.md013.enabled = false  # line-length
 plugins.md024.allow_different_nesting = true  # no-duplicate-headers
@@ -128,9 +112,5 @@ combine-as-imports = true
 known-first-party = ["lm_eval"]
 lines-after-imports = 2

-# required to include yaml files in pip installation
-[tool.setuptools.package-data]
-lm_eval = ["**/*.yaml", "tasks/**/*"]
-
 [tool.setuptools.packages.find]
 include = ["lm_eval*"]
"""
Tests for the config loader pure functions.
Note: _import_func_in_yml uses LRU caching, so file changes during runtime
won't be detected unless the cache is cleared.
Test coverage:
- _mk_function_ctor:
- test_mk_function_ctor_with_resolve_false: raw string returned when resolve=False
- test_mk_function_ctor_with_resolve_true: actual function import when resolve=True
- _make_loader:
- test_make_loader_creates_loader_class: creates YAML loader with !function support
- test_make_loader_caching: loader classes cached by parameters
- _import_func_in_yml:
- test_import_local_module: imports from local .py files
- test_import_nested_local_module: handles dot-separated nested paths
- test_import_standard_module: falls back to standard library imports
- test_import_caching: LRU cache behavior
- test_import_mtime_sensitivity: cache behavior with file changes
- load_yaml():
- test_load_simple_yaml: basic YAML parsing
- test_load_with_function_resolved: !function tags resolved to callables
- test_load_with_function_not_resolved: !function tags become strings
- test_load_with_includes: include files merged, main values win
- test_load_with_absolute_include: absolute path includes
- test_load_without_includes_resolution: includes preserved when disabled
- test_load_include_cycle_detection: circular includes raise error
- test_load_multiple_includes: include order precedence (later includes override earlier, main overrides all)
- test_load_recursive_includes: nested includes (main->inc1->inc2, main overrides inc1 overrides inc2)
- test_load_expanduser_path: ~ paths expanded
"""
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from lm_eval.tasks._config_loader import (
_Base,
_import_func_in_yml,
_make_loader,
_mk_function_ctor,
import_fun_from_str,
load_yaml,
)
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as td:
yield Path(td)
@pytest.fixture
def yaml_file(temp_dir):
def _create_yaml(content, filename="test.yaml"):
file_path = temp_dir / filename
file_path.write_text(content)
return file_path
return _create_yaml
@pytest.fixture
def python_module(temp_dir):
def _create_module(content, filename="utils.py"):
file_path = temp_dir / filename
file_path.write_text(content)
return file_path
return _create_module
class TestMkFunctionCtor:
"""Tests for the YAML !function constructor factory."""
def test_mk_function_ctor_with_resolve_false(self, temp_dir):
"""When resolve=False, should return a string."""
ctor = _mk_function_ctor(temp_dir, resolve=False)
loader = MagicMock()
node = MagicMock()
loader.construct_scalar.return_value = "module.function"
result = ctor(loader, node)
assert isinstance(result, str)
def test_mk_function_ctor_with_resolve_true(self, temp_dir, python_module):
"""When resolve=True, should import and return the actual function."""
# Create a local module
python_module("def test_func(x):\n return x * 2\n")
ctor = _mk_function_ctor(temp_dir, resolve=True)
loader = MagicMock()
node = MagicMock()
loader.construct_scalar.return_value = "utils.test_func"
result = ctor(loader, node)
assert callable(result)
assert result(5) == 10
class TestMakeLoader:
"""Tests for YAML loader class creation and caching."""
def test_make_loader_creates_loader_class(self, temp_dir):
loader_cls = _make_loader(temp_dir, resolve_funcs=True)
assert issubclass(loader_cls, _Base)
# !function constructor should be registered
constructors = loader_cls.yaml_constructors
assert "!function" in constructors
def test_make_loader_caching(self, temp_dir):
"""Loader classes should be cached by parameters."""
# Clear cache first
_make_loader.cache_clear()
loader1 = _make_loader(temp_dir, resolve_funcs=True)
loader2 = _make_loader(temp_dir, resolve_funcs=True)
loader3 = _make_loader(temp_dir, resolve_funcs=False)
assert loader1 is loader2 # Same params = same class
assert loader1 is not loader3 # Different params = different class
class TestImportFunction:
"""Tests for dynamic function importing with mtime-based module caching."""
def test_import_local_module(self, temp_dir, python_module):
# Create a local module
python_module("def local_func(x, y):\n return x + y\n")
func = _import_func_in_yml("utils.local_func", temp_dir)
assert callable(func)
assert func(2, 3) == 5
def test_import_nested_local_module(self, temp_dir):
"""Should handle dot-separated paths for nested modules."""
# Create nested directory structure
(temp_dir / "sub").mkdir()
(temp_dir / "sub" / "module.py").write_text(
"def nested_func():\n return 'nested'\n"
)
func = _import_func_in_yml("sub.module.nested_func", temp_dir)
assert callable(func)
assert func() == "nested"
def test_import_standard_module(self, temp_dir):
"""Falls back to standard import for non-local modules."""
# Import from standard library
func = _import_func_in_yml("os.path.join", temp_dir)
assert callable(func)
assert func("a", "b") in ("a/b", "a\\b") # Unix or Windows
def test_import_caching(self, temp_dir, python_module):
# Clear cache first
_import_func_in_yml.cache_clear()
python_module("def cached_func():\n return 42\n")
func1 = _import_func_in_yml("utils.cached_func", temp_dir)
func2 = _import_func_in_yml("utils.cached_func", temp_dir)
assert func1 is func2 # Cached
def test_import_mtime_sensitivity(self, temp_dir):
"""Verifies LRU cache behavior - file changes require cache clear."""
# Clear the LRU cache
_import_func_in_yml.cache_clear()
# Create a module
module_path = temp_dir / "test_mtime.py"
module_path.write_text("value = 1\n")
# Import it
import_key = "test_mtime.value"
value1 = _import_func_in_yml(import_key, temp_dir)
assert value1 == 1
value2 = _import_func_in_yml(import_key, temp_dir)
assert value2 == 1 # From cache
_import_func_in_yml.cache_clear()
value3 = _import_func_in_yml(import_key, temp_dir)
assert value3 == 1 # Re-imported
class TestImportFunFromStr:
"""Tests for import_fun_from_str function."""
def test_import_from_absolute_path(self, temp_dir):
"""Test importing function from absolute path."""
# Create a test module
module_path = temp_dir / "test_module.py"
module_path.write_text("def test_func(x):\n return x * 2\n")
# Import using absolute path
func = import_fun_from_str(f"{module_path.with_suffix('')}.test_func")
assert callable(func)
assert func(5) == 10
def test_import_with_py_extension(self, temp_dir):
"""Test importing when .py is included in the path."""
# Create a test module
module_path = temp_dir / "test_module.py"
module_path.write_text("def test_func(x):\n return x + 10\n")
# Import with .py in the path
func = import_fun_from_str(f"{module_path}.test_func")
assert callable(func)
assert func(5) == 15
def test_import_nested_function(self, temp_dir):
"""Test importing from nested module structure."""
# Create nested directory
(temp_dir / "subdir").mkdir()
module_path = temp_dir / "subdir" / "nested.py"
module_path.write_text("def nested_func():\n return 'nested'\n")
# Import from nested path
func = import_fun_from_str(f"{module_path.with_suffix('')}.nested_func")
assert callable(func)
assert func() == "nested"
def test_import_missing_module(self, temp_dir):
"""Test error when module doesn't exist."""
with pytest.raises(ImportError, match="Module file not found"):
import_fun_from_str(f"{temp_dir}/nonexistent.test_func")
def test_import_missing_function(self, temp_dir):
"""Test error when function doesn't exist in module."""
module_path = temp_dir / "test_module.py"
module_path.write_text("def other_func():\n pass\n")
with pytest.raises(AttributeError, match="Function 'missing_func' not found"):
import_fun_from_str(f"{module_path.with_suffix('')}.missing_func")
def test_import_invalid_format(self):
"""Test error with invalid path format."""
with pytest.raises(ValueError, match="Invalid path format"):
import_fun_from_str("/path/without/function")
def test_import_caching(self, temp_dir):
"""Test that modules are cached by mtime."""
# Clear any existing cache
import sys
keys_to_remove = [k for k in sys.modules if str(temp_dir) in k]
for k in keys_to_remove:
del sys.modules[k]
module_path = temp_dir / "cached_module.py"
module_path.write_text(
"call_count = 0\ndef func():\n global call_count\n call_count += 1\n return call_count\n"
)
# First import
func1 = import_fun_from_str(f"{module_path.with_suffix('')}.func")
_result1 = func1()
# Second import should use cached module
func2 = import_fun_from_str(f"{module_path.with_suffix('')}.func")
result2 = func2()
# Both should refer to the same module instance
assert func1 is func2
assert result2 == 2 # call_count incremented
class TestLoad:
"""Tests for the main YAML loading function with includes and function resolution."""
def test_load_simple_yaml(self, yaml_file):
content = """
task: test_task
description: A test task
metric: accuracy
"""
file_path = yaml_file(content)
result = load_yaml(file_path)
assert result["task"] == "test_task"
assert result["description"] == "A test task"
assert result["metric"] == "accuracy"
def test_load_with_function_resolved(self, yaml_file, python_module):
# Create a module with a function
python_module("def process_doc(doc):\n return doc.upper()\n")
content = """
task: test_task
doc_to_text: !function utils.process_doc
"""
file_path = yaml_file(content)
result = load_yaml(file_path, resolve_func=True)
assert callable(result["doc_to_text"])
assert result["doc_to_text"]("hello") == "HELLO"
def test_load_with_function_not_resolved(self, yaml_file):
content = """
task: test_task
doc_to_text: !function utils.process_doc
"""
file_path = yaml_file(content)
result = load_yaml(file_path, resolve_func=False)
assert isinstance(result["doc_to_text"], str)
# When resolve_func=False, it returns the full path + function spec
assert result["doc_to_text"].endswith("utils.process_doc")
assert result["doc_to_text"] == str(file_path.parent / "utils.process_doc")
def test_load_with_includes(self, temp_dir, yaml_file):
"""Include files are merged with local values taking precedence."""
# Create included file with shared_value: 42
included_content = """
shared_metric: f1_score
shared_value: 42
"""
yaml_file(included_content, "included.yaml")
# Create main file that also defines shared_value: 100
main_content = """
include:
- included.yaml
task: main_task
shared_value: 100
"""
main_path = yaml_file(main_content, "main.yaml")
result = load_yaml(main_path, recursive=True)
assert result["task"] == "main_task"
assert result["shared_metric"] == "f1_score"
# Verify main file value (100) overrides included file value (42)
assert result["shared_value"] == 100 # Local wins
assert "include" not in result
def test_load_with_absolute_include(self, temp_dir, yaml_file):
# Create included file in different directory
other_dir = temp_dir / "other"
other_dir.mkdir()
included_path = other_dir / "included.yaml"
included_path.write_text("included_key: included_value\n")
# Create main file with absolute path
main_content = f"""
include:
- {included_path}
main_key: main_value
"""
main_path = yaml_file(main_content)
result = load_yaml(main_path, recursive=True)
assert result["main_key"] == "main_value"
assert result["included_key"] == "included_value"
def test_load_without_includes_resolution(self, yaml_file):
content = """
include:
- other.yaml
task: test_task
"""
file_path = yaml_file(content)
result = load_yaml(file_path, recursive=False)
assert result["include"] == ["other.yaml"]
assert result["task"] == "test_task"
def test_load_include_cycle_detection(self, temp_dir, yaml_file):
"""Circular includes should raise ValueError."""
# Create circular includes
yaml_file("include:\n - b.yaml\n", "a.yaml")
yaml_file("include:\n - c.yaml\n", "b.yaml")
yaml_file("include:\n - a.yaml\n", "c.yaml")
with pytest.raises(ValueError, match="Include cycle"):
load_yaml(temp_dir / "a.yaml")
def test_load_multiple_includes(self, temp_dir, yaml_file):
"""Multiple includes are processed in order, later values override earlier."""
# Create multiple included files
yaml_file("key1: value1\n", "inc1.yaml") # Sets key1 to "value1"
yaml_file(
"key2: value2\nmain_key: should_be_ignored\n", "inc2.yaml"
) # Tries to set main_key
yaml_file(
"key3: value3\nkey1: override\n", "inc3.yaml"
) # Overrides key1 to "override"
# Include order matters: inc3 comes after inc1, so its key1 value wins
main_content = """
include:
- inc1.yaml
- inc2.yaml
- inc3.yaml
main_key: main_value
"""
main_path = yaml_file(main_content)
result = load_yaml(main_path)
# Verify inc3's value overrides inc1's value for key1
assert result["key1"] == "override" # Last include wins
assert result["key2"] == "value2"
assert result["key3"] == "value3"
# Verify main file's value is NOT overridden by inc2.yaml
assert result["main_key"] == "main_value" # Main file wins over includes
def test_load_recursive_includes(self, temp_dir, yaml_file):
"""Includes can be recursive - inc1 can include inc2."""
# Create inc2.yaml (deepest level)
yaml_file(
"deep_key: deep_value\nshared_key: from_inc2\nshared_middle: inc2_middle\n",
"inc2.yaml",
)
# Create inc1.yaml that includes inc2.yaml
inc1_content = """include:
- inc2.yaml
middle_key: middle_value
shared_key: from_inc1
shared_middle: inc1_middle
"""
yaml_file(inc1_content, "inc1.yaml")
# Create main.yaml that includes inc1.yaml
main_content = """include:
- inc1.yaml
top_key: top_value
shared_key: from_main
"""
main_path = yaml_file(main_content, "main.yaml")
result = load_yaml(main_path)
# All keys should be present
assert result["deep_key"] == "deep_value" # From inc2
assert result["middle_key"] == "middle_value" # From inc1
assert result["top_key"] == "top_value" # From main
# Verify override order: main > inc1 > inc2
assert result["shared_key"] == "from_main" # Main wins
assert result["shared_middle"] == "inc1_middle" # inc1 wins over inc2
assert "include" not in result # Include directives removed
def test_load_expanduser_path(self, yaml_file):
"""Verifies that load() calls expanduser() on paths with ~."""
content = "test: value\n"
file_path = yaml_file(content)
# Mock expanduser to verify it's called and control the expansion
with patch.object(Path, "expanduser") as mock_expand:
mock_expand.return_value = file_path
result = load_yaml("~/test.yaml")
mock_expand.assert_called_once()
assert result["test"] == "value"
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
import importlib # import importlib
import os # import os
import sys # import sys
from datetime import datetime # from datetime import datetime
from typing import List, Optional, Tuple # from typing import List, Optional, Tuple
#
import pytest # import pytest
import torch # import torch
#
from lm_eval.caching.cache import PATH # from lm_eval.caching.cache import PATH
#
#
MODULE_DIR = os.path.dirname(os.path.realpath(__file__)) # MODULE_DIR = os.path.dirname(os.path.realpath(__file__))
#
# NOTE the script this loads uses simple_evaluate # # NOTE the script this loads uses simple_evaluate
# TODO potentially test both the helper script and the normal script # # TODO potentially test both the helper script and the normal script
sys.path.append(f"{MODULE_DIR}/../scripts") # sys.path.append(f"{MODULE_DIR}/../scripts")
model_loader = importlib.import_module("requests_caching") # model_loader = importlib.import_module("requests_caching")
run_model_for_task_caching = model_loader.run_model_for_task_caching # run_model_for_task_caching = model_loader.run_model_for_task_caching
#
os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1" # os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1"
DEFAULT_TASKS = ["lambada_openai", "sciq"] # DEFAULT_TASKS = ["lambada_openai", "sciq"]
#
#
@pytest.fixture(autouse=True) # @pytest.fixture(autouse=True)
def setup_and_teardown(): # def setup_and_teardown():
# Setup # # Setup
torch.use_deterministic_algorithms(False) # torch.use_deterministic_algorithms(False)
clear_cache() # clear_cache()
# Yields control back to the test function # # Yields control back to the test function
yield # yield
# Cleanup here # # Cleanup here
#
#
def clear_cache(): # def clear_cache():
if os.path.exists(PATH): # if os.path.exists(PATH):
cache_files = os.listdir(PATH) # cache_files = os.listdir(PATH)
for file in cache_files: # for file in cache_files:
file_path = f"{PATH}/{file}" # file_path = f"{PATH}/{file}"
os.unlink(file_path) # os.unlink(file_path)
#
#
# leaving tasks here to allow for the option to select specific task files # # leaving tasks here to allow for the option to select specific task files
def get_cache_files(tasks: Optional[List[str]] = None) -> Tuple[List[str], List[str]]: # def get_cache_files(tasks: Optional[List[str]] = None) -> Tuple[List[str], List[str]]:
cache_files = os.listdir(PATH) # cache_files = os.listdir(PATH)
#
file_task_names = [] # file_task_names = []
#
for file in cache_files: # for file in cache_files:
file_without_prefix = file.split("-")[1] # file_without_prefix = file.split("-")[1]
file_without_prefix_and_suffix = file_without_prefix.split(".")[0] # file_without_prefix_and_suffix = file_without_prefix.split(".")[0]
file_task_names.extend([file_without_prefix_and_suffix]) # file_task_names.extend([file_without_prefix_and_suffix])
#
return cache_files, file_task_names # return cache_files, file_task_names
#
#
def assert_created(tasks: List[str], file_task_names: List[str]): # def assert_created(tasks: List[str], file_task_names: List[str]):
tasks.sort() # tasks.sort()
file_task_names.sort() # file_task_names.sort()
#
assert tasks == file_task_names # assert tasks == file_task_names
#
#
@pytest.mark.parametrize("tasks", [DEFAULT_TASKS]) # @pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
def requests_caching_true(tasks: List[str]): # def requests_caching_true(tasks: List[str]):
run_model_for_task_caching(tasks=tasks, cache_requests="true") # run_model_for_task_caching(tasks=tasks, cache_requests="true")
#
cache_files, file_task_names = get_cache_files() # cache_files, file_task_names = get_cache_files()
print(file_task_names) # print(file_task_names)
assert_created(tasks=tasks, file_task_names=file_task_names) # assert_created(tasks=tasks, file_task_names=file_task_names)
#
#
@pytest.mark.parametrize("tasks", [DEFAULT_TASKS]) # @pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
def requests_caching_refresh(tasks: List[str]): # def requests_caching_refresh(tasks: List[str]):
run_model_for_task_caching(tasks=tasks, cache_requests="true") # run_model_for_task_caching(tasks=tasks, cache_requests="true")
#
timestamp_before_test = datetime.now().timestamp() # timestamp_before_test = datetime.now().timestamp()
#
run_model_for_task_caching(tasks=tasks, cache_requests="refresh") # run_model_for_task_caching(tasks=tasks, cache_requests="refresh")
#
cache_files, file_task_names = get_cache_files() # cache_files, file_task_names = get_cache_files()
#
for file in cache_files: # for file in cache_files:
modification_time = os.path.getmtime(f"{PATH}/{file}") # modification_time = os.path.getmtime(f"{PATH}/{file}")
assert modification_time > timestamp_before_test # assert modification_time > timestamp_before_test
#
tasks.sort() # tasks.sort()
file_task_names.sort() # file_task_names.sort()
#
assert tasks == file_task_names # assert tasks == file_task_names
#
#
@pytest.mark.parametrize("tasks", [DEFAULT_TASKS]) # @pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
def requests_caching_delete(tasks: List[str]): # def requests_caching_delete(tasks: List[str]):
# populate the data first, rerun this test within this test for additional confidence # # populate the data first, rerun this test within this test for additional confidence
# test_requests_caching_true(tasks=tasks) # # test_requests_caching_true(tasks=tasks)
#
run_model_for_task_caching(tasks=tasks, cache_requests="delete") # run_model_for_task_caching(tasks=tasks, cache_requests="delete")
#
cache_files, file_task_names = get_cache_files() # cache_files, file_task_names = get_cache_files()
#
assert len(cache_files) == 0 # assert len(cache_files) == 0
#
#
# useful for locally running tests through the debugger # # useful for locally running tests through the debugger
if __name__ == "__main__": # if __name__ == "__main__":
#
def run_tests(): # def run_tests():
tests = [ # tests = [
# test_requests_caching_true, # # test_requests_caching_true,
# test_requests_caching_refresh, # # test_requests_caching_refresh,
# test_requests_caching_delete, # # test_requests_caching_delete,
] # ]
# Lookups of global names within a loop are inefficient, so copy to a local variable outside of the loop first # # Lookups of global names within a loop are inefficient, so copy to a local variable outside of the loop first
default_tasks = DEFAULT_TASKS # default_tasks = DEFAULT_TASKS
for test_func in tests: # for test_func in tests:
clear_cache() # clear_cache()
test_func(tasks=default_tasks) # test_func(tasks=default_tasks)
#
print("Tests pass") # print("Tests pass")
#
run_tests() # run_tests()
"""Tests for the task index builder that discovers YAML task configurations.
Test coverage:
- TaskIndex._kind_of: identifies task/group/tag/task_list/py_task
- TaskIndex._iter_yaml_files: finds YAML files, ignores __pycache__
- TaskIndex.process_cfg: creates the correct TaskEntry for each type
- TaskIndex._register_tags: creates TAG entries for task tags
- TaskIndex.build: discovers all task types in a directory tree
"""
import tempfile
from pathlib import Path
import pytest
from lm_eval.tasks.index import Kind, TaskIndex
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as td:
yield Path(td)
@pytest.fixture
def yaml_file(temp_dir):
def _create_yaml(content, path="test.yaml"):
file_path = temp_dir / path
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.write_text(content)
return file_path
return _create_yaml
class TestKindOf:
"""Tests for identifying task configuration types."""
def test_kind_of_task(self):
"""Single task with string name."""
cfg = {"task": "my_task", "dataset_path": "data"}
assert TaskIndex._kind_of(cfg) == Kind.TASK
def test_kind_of_group(self):
"""Group has task as list."""
cfg = {"task": ["task1", "task2"], "group": "my_group"}
assert TaskIndex._kind_of(cfg) == Kind.GROUP
def test_kind_of_py_task(self):
"""Python task has class field."""
cfg = {"task": "my_task", "class": "tasks.MyTask"}
assert TaskIndex._kind_of(cfg) == Kind.PY_TASK
def test_kind_of_task_list(self):
"""Task list has task_list field."""
cfg = {"task_list": ["task1", "task2"]}
assert TaskIndex._kind_of(cfg) == Kind.TASK_LIST
def test_kind_of_unknown(self):
"""Unknown config raises ValueError."""
cfg = {"unknown": "field"}
with pytest.raises(ValueError, match="Unknown config shape"):
TaskIndex._kind_of(cfg)
class TestIterYamlFiles:
"""Tests for YAML file discovery."""
def test_iter_yaml_files_simple(self, temp_dir):
"""Finds .yaml files in directory tree."""
# Create some yaml files
(temp_dir / "task1.yaml").touch()
(temp_dir / "subdir").mkdir()
(temp_dir / "subdir" / "task2.yaml").touch()
(temp_dir / "other.txt").touch()
builder = TaskIndex()
yaml_files = list(builder._iter_yaml_files(temp_dir))
assert len(yaml_files) == 2
names = {f.name for f in yaml_files}
assert names == {"task1.yaml", "task2.yaml"}
def test_iter_yaml_files_ignores_pycache(self, temp_dir):
"""Ignores files in __pycache__ directories."""
(temp_dir / "task.yaml").touch()
(temp_dir / "__pycache__").mkdir()
(temp_dir / "__pycache__" / "ignored.yaml").touch()
(temp_dir / ".ipynb_checkpoints").mkdir()
(temp_dir / ".ipynb_checkpoints" / "also_ignored.yaml").touch()
builder = TaskIndex()
yaml_files = list(builder._iter_yaml_files(temp_dir))
assert len(yaml_files) == 1
assert yaml_files[0].name == "task.yaml"
class TestProcessCfg:
"""Tests for processing individual config files."""
def test_process_task(self, temp_dir):
"""Regular task creates TASK entry."""
cfg = {"task": "my_task", "tag": ["tag1", "tag2"]}
path = temp_dir / "task.yaml"
index = {}
builder = TaskIndex()
builder.process_cfg(cfg, path, index)
assert "my_task" in index
entry = index["my_task"]
assert entry.name == "my_task"
assert entry.kind == Kind.TASK
assert entry.yaml_path == path
assert entry.tags == {"tag1", "tag2"}
def test_process_group(self, temp_dir):
"""Group creates GROUP entry."""
cfg = {"task": ["t1", "t2"], "group": "my_group", "tag": ["grp_tag"]}
path = temp_dir / "group.yaml"
index = {}
builder = TaskIndex()
builder.process_cfg(cfg, path, index)
assert "my_group" in index
entry = index["my_group"]
assert entry.name == "my_group"
assert entry.kind == Kind.GROUP
assert entry.yaml_path == path
assert entry.tags == {"grp_tag"}
def test_process_py_task(self, temp_dir):
"""Python task creates PY_TASK entry."""
cfg = {"task": "py_task", "class": "MyTask", "tag": ["py_tag"]}
path = temp_dir / "py_task.yaml"
index = {}
builder = TaskIndex()
builder.process_cfg(cfg, path, index)
assert "py_task" in index
entry = index["py_task"]
assert entry.name == "py_task"
assert entry.kind == Kind.PY_TASK
assert entry.yaml_path is None # Python tasks don't store yaml_path
assert entry.tags == {"py_tag"}
def test_process_task_list(self, temp_dir):
"""Task list creates entries for each task."""
cfg = {
"task_list": [
"simple_task",
{"task": "complex_task", "tag": ["tag1", "tag2"]},
],
}
path = temp_dir / "list.yaml"
index = {}
builder = TaskIndex()
# The implementation currently calls entry.get() on bare string entries,
# so string items in task_list raise AttributeError. This test documents
# that (buggy) behavior rather than the intended one.
with pytest.raises(AttributeError, match="'str' object has no attribute 'get'"):
builder.process_cfg(cfg, path, index)
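# A plausible normalization that would lift this restriction (hedged; not the
# shipped fix): coerce bare strings before any field access, e.g.
#
#   entry = {"task": entry} if isinstance(entry, str) else entry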
def test_process_task_list_dict_entries(self, temp_dir):
"""Task list with only dict entries works."""
cfg = {
"task_list": [
{"task": "task1"},
{"task": "task2", "tag": ["tag1", "tag2"]},
],
}
path = temp_dir / "list.yaml"
index = {}
builder = TaskIndex()
builder.process_cfg(cfg, path, index)
# Task without tags
assert "task1" in index
task1 = index["task1"]
assert task1.kind == Kind.TASK
assert task1.yaml_path == path
assert task1.tags == set()
# Task with tags
assert "task2" in index
task2 = index["task2"]
assert task2.kind == Kind.TASK
assert task2.yaml_path == path
assert task2.tags == {"tag1", "tag2"}
class TestRegisterTags:
"""Tests for tag registration."""
def test_register_single_tag(self):
"""Single tag creates TAG entry."""
index = {}
builder = TaskIndex()
builder._register_tags("task1", "my_tag", index)
assert "my_tag" in index
tag_entry = index["my_tag"]
assert tag_entry.kind == Kind.TAG
assert tag_entry.yaml_path is None
assert "task1" in tag_entry.tags # TAG entries use tags set for task names
def test_register_multiple_tags(self):
"""Multiple tags create multiple TAG entries."""
index = {}
builder = TaskIndex()
builder._register_tags("task1", ["tag1", "tag2"], index)
assert "tag1" in index
assert "tag2" in index
assert "task1" in index["tag1"].tags
assert "task1" in index["tag2"].tags
def test_register_tags_accumulates(self):
"""Multiple tasks can have same tag."""
index = {}
builder = TaskIndex()
builder._register_tags("task1", "shared_tag", index)
builder._register_tags("task2", "shared_tag", index)
assert "shared_tag" in index
tag_entry = index["shared_tag"]
assert tag_entry.tags == {"task1", "task2"}
class TestBuild:
"""Tests for the main build method."""
def test_build_empty_directory(self, temp_dir):
"""Empty directory returns empty index."""
builder = TaskIndex()
index = builder.build([temp_dir])
assert index == {}
def test_build_single_task(self, temp_dir, yaml_file):
"""Single task file is discovered."""
yaml_file("task: my_task\ndataset_path: data\n")
builder = TaskIndex()
index = builder.build([temp_dir])
assert len(index) == 1
assert "my_task" in index
assert index["my_task"].kind == Kind.TASK
def test_build_mixed_types(self, temp_dir, yaml_file):
"""Discovers various task types."""
# Regular task with list tag format
yaml_file("task: task1\ntag: [common]\n", "task1.yaml")
# Group
yaml_file("task: [t1, t2]\ngroup: group1\n", "group1.yaml")
# Task list with only dict entries (to avoid the bug)
yaml_file(
"task_list:\n - task: task2\n - task: task3\n tag: [common]\n",
"list.yaml",
)
# Python task
yaml_file("task: py_task\nclass: MyClass\n", "python.yaml")
builder = TaskIndex()
index = builder.build([temp_dir])
# Check all entries exist
assert "task1" in index
assert "group1" in index
assert "task2" in index
assert "task3" in index
assert "py_task" in index
assert "common" in index # Tag entry
# Check types
assert index["task1"].kind == Kind.TASK
assert index["group1"].kind == Kind.GROUP
assert index["task2"].kind == Kind.TASK
assert index["task3"].kind == Kind.TASK
assert index["py_task"].kind == Kind.PY_TASK
assert index["common"].kind == Kind.TAG
# Check tag has both tasks
assert index["common"].tags == {"task1", "task3"}
def test_build_nested_directories(self, temp_dir, yaml_file):
"""Discovers tasks in nested directories."""
yaml_file("task: root_task\n", "root.yaml")
yaml_file("task: sub_task\n", "subdir/sub.yaml")
yaml_file("task: deep_task\n", "subdir/deeper/deep.yaml")
builder = TaskIndex()
index = builder.build([temp_dir])
assert len(index) == 3
assert all(name in index for name in ["root_task", "sub_task", "deep_task"])
def test_build_skips_invalid_yaml(self, temp_dir, yaml_file):
"""Skips files that fail to parse."""
yaml_file("task: valid_task\n", "valid.yaml")
yaml_file("invalid: [\n", "invalid.yaml") # Invalid YAML
builder = TaskIndex()
index = builder.build([temp_dir])
assert len(index) == 1
assert "valid_task" in index
def test_build_multiple_paths(self, temp_dir):
"""Can search multiple root paths."""
# Create two separate directories
dir1 = temp_dir / "dir1"
dir2 = temp_dir / "dir2"
dir1.mkdir()
dir2.mkdir()
(dir1 / "task1.yaml").write_text("task: task1\n")
(dir2 / "task2.yaml").write_text("task: task2\n")
builder = TaskIndex()
index = builder.build([dir1, dir2])
assert len(index) == 2
assert "task1" in index
assert "task2" in index
...@@ -64,10 +64,10 @@ def test_python_task_inclusion( ...@@ -64,10 +64,10 @@ def test_python_task_inclusion(
verbosity="INFO", include_path=str(custom_task_files_dir) verbosity="INFO", include_path=str(custom_task_files_dir)
) )
# check if the python task enters the global task_index # check if the python task enters the global task_index
assert custom_task_name in task_manager.task_index assert custom_task_name in task_manager._index
# check if subtask is present # check if subtask is present
assert custom_task_name in task_manager.all_subtasks assert custom_task_name in task_manager._index
# check if tag is present # check if tag is present
assert custom_task_tag in task_manager.all_tags assert custom_task_tag in task_manager._index
# check if it can be loaded by tag (custom_task_tag) # check if it can be loaded by tag (custom_task_tag)
assert custom_task_name in task_manager.load_task_or_group(custom_task_tag) assert custom_task_name in task_manager.load_task_or_group(custom_task_tag)
#!/usr/bin/env python3
"""
Walkthrough tests using real dataset configurations.
These tests use YAML configs backed by an existing dataset (hellaswag) to
enable a complete walkthrough of the task loading system, including:
- Basic task loading
- Task list functionality
- Group functionality
- Include inheritance
- Issue #2158 fix (include processing preserving task names)
"""
import os
import pytest
from lm_eval.tasks import TaskManager, get_task_dict
class TestWalkthroughConfigs:
"""Test walkthrough configurations for easier code demonstration"""
@pytest.fixture(autouse=True)
def setup_task_manager(self):
"""Set up TaskManager with test configs directory"""
test_configs_dir = os.path.join(os.path.dirname(__file__), "test_configs")
self.tm = TaskManager(include_path=test_configs_dir, include_defaults=False)
def test_simple_task_loading(self):
"""Test basic task loading - walkthrough starting point"""
# Simple task should be indexed
assert "simple_task" in self.tm.all_tasks
assert self.tm._name_is_task("simple_task")
# Load the task
task_dict = get_task_dict(["simple_task"], task_manager=self.tm)
assert "simple_task" in task_dict
# Verify task configuration
task_obj = task_dict["simple_task"]
assert hasattr(task_obj, "config")
assert task_obj.config.task == "simple_task"
def test_task_list_functionality(self):
"""Test task_list feature - multiple tasks sharing config"""
# All task_list tasks should be indexed as individual tasks
expected_tasks = ["task_list_fs0", "task_list_fs1", "task_list_fs3"]
for task_name in expected_tasks:
assert task_name in self.tm.all_tasks, f"Task {task_name} not indexed"
assert self.tm._name_is_task(task_name), (
f"Task {task_name} not recognized as task"
)
# Load all tasks from the task_list
task_dict = get_task_dict(expected_tasks, task_manager=self.tm)
# Each should be a separate task object
assert len(task_dict) == 3
for task_name in expected_tasks:
assert task_name in task_dict
task_obj = task_dict[task_name]
assert task_obj.config.task == task_name
# Verify different num_fewshot values were applied
assert task_dict["task_list_fs0"].config.num_fewshot == 0
assert task_dict["task_list_fs1"].config.num_fewshot == 1
assert task_dict["task_list_fs3"].config.num_fewshot == 3
def test_group_functionality(self):
"""Test group loading with task-specific overrides"""
# Group should be indexed
assert "test_group" in self.tm.all_groups
assert self.tm._name_is_group("test_group")
# Load the group
task_dict = get_task_dict(["test_group"], task_manager=self.tm)
# Should contain the group object and its subtasks
assert len(task_dict) == 1
group_obj = list(task_dict.keys())[0]
subtasks = task_dict[group_obj]
# Check expected subtasks
expected_subtasks = ["group_task_fs0", "group_task_fs2"]
for subtask_name in expected_subtasks:
assert subtask_name in subtasks
# Verify different configurations were applied
fs0_task = subtasks["group_task_fs0"]
fs2_task = subtasks["group_task_fs2"]
assert fs0_task.config.num_fewshot == 0
assert fs2_task.config.num_fewshot == 2
def test_include_inheritance(self):
"""Test include functionality and inheritance"""
# Test direct include tasks (these were created as separate files)
include_tasks = ["include_task_fs0", "include_task_fs1", "include_task_fs5"]
for task_name in include_tasks:
assert task_name in self.tm.all_tasks
# Load tasks that use include
task_dict = get_task_dict(
include_tasks[:1], task_manager=self.tm
) # Just test first one
# Should inherit from base config
task_obj = task_dict["include_task_fs0"]
# Should inherit dataset_path from include
assert task_obj.config.dataset_path == "json"
# Should inherit output_type from include
assert task_obj.config.output_type == "multiple_choice"
# Should preserve specific task name (not base_task_name)
assert task_obj.config.task == "include_task_fs0"
# Should have overridden num_fewshot
assert task_obj.config.num_fewshot == 0
def test_issue_2158_fix_demo(self):
"""
Test issue #2158 fix - multiple tasks with same include in group.
This demonstrates the specific scenario that was failing before the fix.
"""
# Group with multiple tasks using same include should work
assert "include_group" in self.tm.all_groups
# This should NOT raise a duplicate detection error
# Before the fix, this would fail with:
# "Please call groups which overlap their constituent tasks in separate evaluation runs"
task_dict = get_task_dict(["include_group"], task_manager=self.tm)
# Should successfully load the group
assert len(task_dict) == 1
group_obj = list(task_dict.keys())[0]
subtasks = task_dict[group_obj]
# Check all expected tasks are present with correct names
expected_tasks = ["include_task_fs0", "include_task_fs1", "include_task_fs5"]
for task_name in expected_tasks:
assert task_name in subtasks, f"Task {task_name} missing from group"
task_obj = subtasks[task_name]
# CRITICAL: Task name should be preserved, not overwritten by include
assert task_obj.config.task == task_name
# Should inherit base config from include
assert task_obj.config.dataset_path == "json"
assert task_obj.config.output_type == "multiple_choice"
# Verify different num_fewshot values
assert subtasks["include_task_fs0"].config.num_fewshot == 0
assert subtasks["include_task_fs1"].config.num_fewshot == 1
assert subtasks["include_task_fs5"].config.num_fewshot == 5
def test_config_types_detection(self):
"""Test that different config types are correctly detected"""
# Load various config types to test detection methods
configs = [
# Simple task config
{"task": "walkthrough_simple_task"},
# Group config
{"group": "test_group", "task": ["task1", "task2"]},
# Task list config (would need to be loaded from file)
]
# Test config detection methods
assert self.tm._config_is_task(configs[0])
assert not self.tm._config_is_group(configs[0])
assert not self.tm._config_is_task_list(configs[0])
assert not self.tm._config_is_task(configs[1])
assert self.tm._config_is_group(configs[1])
assert not self.tm._config_is_task_list(configs[1])
# Test task_list detection with actual config
task_list_config = {"task_list": [{"task": "task1"}, {"task": "task2"}]}
assert self.tm._config_is_task_list(task_list_config)
assert not self.tm._config_is_task(task_list_config)
assert not self.tm._config_is_group(task_list_config)
if __name__ == "__main__":
pytest.main([__file__, "-v"])
...@@ -12,6 +12,7 @@ from lm_eval.api.metrics import ( ...@@ -12,6 +12,7 @@ from lm_eval.api.metrics import (
) )
from lm_eval.models.utils import Collator from lm_eval.models.utils import Collator
from lm_eval.utils import ( from lm_eval.utils import (
apply_template,
get_rolling_token_windows, get_rolling_token_windows,
make_disjoint_window, make_disjoint_window,
) )
...@@ -396,3 +397,95 @@ def test_aggregate_stderrs(samples): ...@@ -396,3 +397,95 @@ def test_aggregate_stderrs(samples):
mean_stderr(list(itertools.chain.from_iterable(samples))), mean_stderr(list(itertools.chain.from_iterable(samples))),
atol=1.0e-3, atol=1.0e-3,
) )
def test_apply_template():
"""Test the apply_template function with various scenarios."""
# Test basic variable substitution
result = apply_template("Hello {{name}}!", {"name": "World"})
assert result == "Hello World!"
# Test multiple variables
result = apply_template(
"{{greeting}} {{name}}!", {"greeting": "Hi", "name": "Alice"}
)
assert result == "Hi Alice!"
# Test missing variable (should raise error due to StrictUndefined)
with pytest.raises(Exception): # Jinja2 will raise UndefinedError
apply_template("Hello {{missing}}!", {})
# Test empty template
result = apply_template("", {})
assert result == ""
# Test template with no variables
result = apply_template("Static text", {"unused": "variable"})
assert result == "Static text"
# Test numeric variables
result = apply_template("Count: {{count}}", {"count": 42})
assert result == "Count: 42"
# Test boolean variables
result = apply_template("Flag: {{flag}}", {"flag": True})
assert result == "Flag: True"
# Test list variables
result = apply_template("Items: {{items}}", {"items": [1, 2, 3]})
assert result == "Items: [1, 2, 3]"
# Test regex_replace filter
result = apply_template(
"{{text | regex_replace('[0-9]+', 'X')}}", {"text": "abc123def456"}
)
assert result == "abcXdefX"
# Test regex_replace with count parameter
result = apply_template(
"{{text | regex_replace('[0-9]+', 'X', 1)}}", {"text": "abc123def456"}
)
assert result == "abcXdef456"
# Test complex template with loops
result = apply_template(
"{% for item in items %}{{item}} {% endfor %}", {"items": ["a", "b", "c"]}
)
assert result == "a b c "
# Test conditional template
result = apply_template("{% if flag %}Yes{% else %}No{% endif %}", {"flag": True})
assert result == "Yes"
result = apply_template("{% if flag %}Yes{% else %}No{% endif %}", {"flag": False})
assert result == "No"
# Test whitespace handling (keep_trailing_newline=True)
result = apply_template("Line 1\nLine 2\n", {})
assert result == "Line 1\nLine 2\n"
def test_apply_template_lazy_initialization():
"""Test that the Jinja2 Environment is lazily initialized."""
# Clear any existing environment to test fresh initialization
if hasattr(apply_template, "_env"):
delattr(apply_template, "_env")
# Environment should not exist before first call
assert not hasattr(apply_template, "_env")
# First call should create the environment
apply_template("{{test}}", {"test": "value"})
assert hasattr(apply_template, "_env")
# Store reference to the environment
env = apply_template._env
# Second call should reuse the same environment
apply_template("{{test}}", {"test": "value"})
assert apply_template._env is env # Same object reference
# Environment should have the custom regex_replace filter
assert "regex_replace" in apply_template._env.filters
import os import os
from typing import List, Union from typing import List, Union
from lm_eval.utils import load_yaml_config from lm_eval.tasks._config_loader import load_yaml
# {{{CI}}} # {{{CI}}}
...@@ -12,7 +12,7 @@ from lm_eval.utils import load_yaml_config ...@@ -12,7 +12,7 @@ from lm_eval.utils import load_yaml_config
# reads a text file and returns a list of words # reads a text file and returns a list of words
# used to read the output of the changed txt from tj-actions/changed-files # used to read the output of the changed txt from tj-actions/changed-files
def load_changed_files(file_path: str) -> List[str]: def load_changed_files(file_path: str) -> List[str]:
with open(file_path, "r", encoding="utf-8") as f: with open(file_path, encoding="utf-8") as f:
content = f.read() content = f.read()
words_list = list(content.split()) words_list = list(content.split())
return words_list return words_list
...@@ -26,7 +26,7 @@ def parser(full_path: List[str]) -> List[str]: ...@@ -26,7 +26,7 @@ def parser(full_path: List[str]) -> List[str]:
_output = set() _output = set()
for x in full_path: for x in full_path:
if x.endswith(".yaml") and os.path.exists(x): if x.endswith(".yaml") and os.path.exists(x):
config = load_yaml_config(x, mode="simple") config = load_yaml(x, recursive=True, resolve_func=True)
if isinstance(config["task"], str): if isinstance(config["task"], str):
_output.add(config["task"]) _output.add(config["task"])
elif isinstance(config["task"], list): elif isinstance(config["task"], list):
......