Commit 57adbd35 authored by Baber

refactor configs to files

parent 04e74420
@@ -10,7 +10,7 @@ eval_logger = logging.getLogger(__name__)
MODEL_REGISTRY = {}
DEFAULTS = {
    "model": {"max_length": 2048},
-   "tasks": {"generate_until": {"max_length": 2048}},
+   "tasks": {"generate_until": {"max_gen_toks": 256}},
}
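The only change in this hunk swaps the task-level fallback for `generate_until` requests from `max_length: 2048` to `max_gen_toks: 256`. A minimal sketch of how nested defaults like these could back-fill a user-supplied config; the `merge_defaults` helper is hypothetical and not part of this commit:

# Hypothetical illustration only: deep-merge DEFAULTS beneath user settings.
DEFAULTS = {
    "model": {"max_length": 2048},
    "tasks": {"generate_until": {"max_gen_toks": 256}},
}

def merge_defaults(user_cfg: dict, defaults: dict) -> dict:
    """Recursively fill keys missing from user_cfg with values from defaults."""
    merged = dict(defaults)
    for key, value in user_cfg.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_defaults(value, merged[key])
        else:
            merged[key] = value
    return merged

print(merge_defaults({"tasks": {"generate_until": {"until": ["\n\n"]}}}, DEFAULTS))
# -> {'model': {'max_length': 2048},
#     'tasks': {'generate_until': {'max_gen_toks': 256, 'until': ['\n\n']}}}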
...

lm_eval/api/task.py
@@ -5,9 +5,6 @@ import random
import re
from collections.abc import Callable
from copy import deepcopy
-from dataclasses import asdict, dataclass, field
-from functools import cached_property
-from inspect import getsource
from typing import (
    TYPE_CHECKING,
    Any,
@@ -28,18 +25,11 @@ from tqdm import tqdm
from typing_extensions import deprecated
from lm_eval import utils
-from lm_eval.api import samplers
from lm_eval.api.instance import Instance, OutputType
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
-from lm_eval.api.registry import (
-    AGGREGATION_REGISTRY,
-    DEFAULT_METRIC_REGISTRY,
-    get_aggregation,
-    get_metric,
-    get_metric_aggregation,
-    is_higher_better,
-)
from lm_eval.caching.cache import load_from_cache, save_to_cache
+from lm_eval.config.metric import MetricConfig
+from lm_eval.config.task import TaskConfig
from lm_eval.filters import build_filter_ensemble
from lm_eval.prompts import get_prompt
@@ -52,403 +42,12 @@ ALL_OUTPUT_TYPES = [
]
if TYPE_CHECKING:
-    from lm_eval.api.filter import FilterEnsemble
+    pass
eval_logger = logging.getLogger(__name__)
@dataclass
class MetricConfig:
"""Encapsulates information about a single metric."""
name: str
fn: Optional[Callable] = None
kwargs: Optional[dict] = None
aggregation_fn: Optional[Callable] = None
higher_is_better: bool = True
hf_evaluate: bool = False
is_elementwise: bool = True
@cached_property
def metric_name(self) -> str:
return self.name
@cached_property
def aggregation(self) -> Callable:
if self.aggregation_fn is None:
return get_aggregation(self.name)
return self.aggregation_fn
@cached_property
def _higher_is_better(self) -> bool:
if self.higher_is_better is None:
return is_higher_better(self.name)
return self.higher_is_better
def compute_metric(self, *args, **kwargs) -> Any:
"""Calculates the metric using the provided function and arguments."""
if self.fn is None:
raise ValueError(f"Metric function for {self.name} is not defined.")
return self.fn(*args, **{**self.kwargs, **kwargs})
def compute_aggregation(self, values: List[Any]) -> Any:
"""Computes the aggregation of the metric values."""
if self.aggregation_fn is None:
raise ValueError(f"Aggregation function for {self.name} is not defined.")
return self.aggregation_fn(values)
@dataclass
class RepeatConfig:
"""Encapsulates information about a single repeat."""
repeats: int = 1
metric_fn: Optional[str, Callable] = "pass@N"
kwargs: Optional[dict] = None
@dataclass
class FilterConfig:
"""Encapsulates information about a single filter."""
name: str
fn: Optional[Callable] = None
kwargs: Optional[dict] = None
@dataclass
class FewshotConfig:
sampler: str
samples: list[dict]
process_docs: Optional[Callable] = None
fewshot_indices: Optional[list[int]] = None
@dataclass
class TemplateConfig:
"""Encapsulates information about a template."""
template: str
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
description: str
context_prefix: str
prefix_delimiter: str
context_delimiter: str
answer_suffix: str
target_delimiter: str
choice_format: Optional[str]
choice_delimiter: Optional[str]
fewshot_delimiter: str
metric_list: Optional[Union[list[str], list[MetricConfig]]] = field(
default_factory=lambda: ["acc", "acc_norm"]
)
@dataclass
class MCQTemplateConfig:
"""Encapsulates information about a template.
Would return a sample with the following format:
Question: <doc_to_text(doc)>
A. <doc_to_choice(doc)[0]>
B. <doc_to_choice(doc)[1]>
C. <doc_to_choice(doc)[2]>
D. <doc_to_choice(doc)[3]>
Answer:` doc_to_choice(doc)` for each choice.
"""
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
template = "mcq"
context_prefix: str = "Question:"
prefix_delimiter: str = " "
context_delimiter: str = "\n"
answer_suffix: str = "Answer:"
target_delimiter: str = "\n"
choice_format: Optional[str] = "letters"
choice_delimiter: Optional[str] = "\n"
fewshot_delimiter: str = "\n\n"
metric_list: Optional[list[MetricConfig]] = field(default_factory=lambda: ["acc"])
@dataclass
class ClozeTemplateConfig:
"""Encapsulates information about a template.
Would return a sample with the following format:
Question: <doc_to_text(doc)>
Answer:` <doc_to_target(doc)>`
"""
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
template: str = "cloze"
description: str = ""
context_prefix: str = "Question:"
prefix_delimiter: str = " "
context_delimiter: str = "\n"
answer_suffix: str = "Answer:"
target_delimiter: str = " "
choice_format: Optional[str] = None
choice_delimiter: Optional[str] = None
fewshot_delimiter: str = "\n\n"
metric_list: Optional[list[MetricConfig]] = field(
default_factory=lambda: ["acc", "acc_norm"]
)
@dataclass
class DatasetConfig:
"""Encapsulates information about a dataset."""
dataset_path: Optional[str] = None
dataset_name: Optional[str] = None
dataset_kwargs: Optional[dict] = None
custom_dataset: Optional[Callable] = None
@dataclass
class TaskConfig(dict):
# task naming/registry
task: Optional[str] = None
task_alias: Optional[str] = None
tag: Optional[Union[str, list]] = None
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
custom_dataset: Optional[Callable] = None
dataset_path: Optional[str] = None
dataset_name: Optional[str] = None
dataset_kwargs: Optional[dict] = None
training_split: Optional[str] = None
validation_split: Optional[str] = None
test_split: Optional[str] = None
fewshot_split: Optional[str] = (
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
)
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
process_docs: Optional[Callable] = None
doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None
doc_to_image: Union[Callable, str, None] = None
doc_to_audio: Union[Callable, str, None] = None
unsafe_code: bool = False
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None
use_prompt: Optional[str] = None
description: str = ""
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
fewshot_config: Optional[dict] = None
# runtime configuration options
num_fewshot: Optional[int] = None
# scoring options
metric_list: Optional[list] = None
output_type: OutputType = "generate_until"
generation_kwargs: Optional[dict] = None
repeats: int = 1
filter_list: Optional[list[dict]] = None
should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None
gen_prefix: Optional[str] = None
metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
)
_metric_list: list[MetricConfig] = None
_filter_list: list[FilterConfig] = None
def __post_init__(self) -> None:
if self.generation_kwargs is not None:
if self.output_type != "generate_until":
eval_logger.warning(
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
)
if "temperature" in self.generation_kwargs:
self.generation_kwargs["temperature"] = float(
self.generation_kwargs["temperature"]
)
if "until" not in self.generation_kwargs:
eval_logger.warning(
f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
)
self.generation_kwargs["until"] = [self.fewshot_delimiter]
else:
if self.output_type == "generate_until":
# ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {
"until": (
None
if self.fewshot_delimiter is None
else [self.fewshot_delimiter]
),
"do_sample": False,
"temperature": 0,
}
eval_logger.warning(
f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
)
if self.metric_list and not all("metric" in cfg for cfg in self.metric_list):
raise ValueError("each entry in metric_list must include a 'metric' key")
def get_metrics(self) -> list["MetricConfig"]:
metrics = []
if self.metric_list is None:
# ---------- 1. If no metrics defined, use defaults for output type ----------
_metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
eval_logger.info(
f"No metrics defined in config, using default metrics for {self.output_type}={_metric_list}"
)
metrics.extend(
MetricConfig(
name=metric_name,
fn=get_metric(metric_name),
aggregation_fn=get_metric_aggregation(metric_name),
higher_is_better=is_higher_better(metric_name),
)
for metric_name in _metric_list
)
else:
# ---------- 2. How will the samples be evaluated ----------
for metric_config in self.metric_list:
metric_name = metric_config["metric"]
_metric_fn_kwargs = {
key: metric_config[key]
for key in metric_config
if key
not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
}
_hf_evaluate_metric: bool = metric_config.get("hf_evaluate", False)
_metric_fn = None
_aggregation = None
if self.process_results is not None:
# User will compute metrics inside `process_results()`
_metric_name = None
_metric_fn_kwargs = {}
elif callable(metric_name):
# User passed a function object
_metric_name = metric_name.__name__
_metric_fn = metric_name.__call__
else:
# Normal: look up by name
_metric_name = get_metric(metric_name, _hf_evaluate_metric)
# ---------- 3. Decide how to aggregate examples ----------
if "aggregation" in metric_config:
if isinstance(_agg_name := metric_config["aggregation"], str):
_aggregation = get_aggregation(_agg_name)
elif callable(_agg_name): # noqa: E721
_aggregation = metric_config["aggregation"]
else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
_aggregation = get_metric_aggregation(metric_name)
eval_logger.warning(
f"[Task: {self.task}] metric {metric_name} is defined, but aggregation is not. "
f"using default "
f"aggregation={INV_AGG_REGISTRY[_aggregation]}"
)
# ---------- 4. Determine “higher-is-better” semantics ----------
if "higher_is_better" in metric_config:
_higher_is_better = metric_config["higher_is_better"]
else:
eval_logger.warning(
f"[Task: {self.task}] metric {metric_name} is defined, but higher_is_better is not. "
f"using default "
f"higher_is_better={is_higher_better(metric_name)}"
)
_higher_is_better = is_higher_better(metric_name)
metrics.append(
MetricConfig(
name=_metric_name,
fn=_metric_fn,
kwargs=_metric_fn_kwargs,
aggregation_fn=_aggregation,
higher_is_better=_higher_is_better,
hf_evaluate=_hf_evaluate_metric,
)
)
return metrics
def get_filters(self) -> list["FilterEnsemble"]:
if not self.filter_list:
eval_logger.debug(
"No custom filters defined; falling back to 'take_first' for handling repeats."
)
return [build_filter_ensemble("none", [["take_first", None]])]
else:
def _strip_fn(d: dict) -> dict:
return {k: v for k, v in d.items() if k != "function"}
configs = (
self.filter_list.values()
if isinstance(self.filter_list, dict)
else self.filter_list
)
return [
build_filter_ensemble(
filter_name=cfg["name"],
components=[[_strip_fn(f) for f in cfg["filter"]]],
)
for cfg in configs
]
def __getitem__(self, item):
return getattr(self, item)
def __setitem__(self, item, value):
return setattr(self, item, value)
def to_dict(self, keep_callable: bool = False) -> dict:
"""Return a printable dict with Nones stripped and callables serialised.
:return: dict
A printable dictionary version of the TaskConfig object.
"""
def _maybe_serialize(val):
return (
self.serialize_function(val, keep_callable=keep_callable)
if callable(val)
else val
)
cfg = asdict(self)
return {
k: [{mk: _maybe_serialize(mv) for mk, mv in md.items()} for md in v]
if k == "metric_list"
else _maybe_serialize(v)
for k, v in cfg.items()
if v is not None
}
def serialize_function(
self, value: Union[Callable, str], keep_callable=False
) -> Union[Callable, str]:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
Otherwise, attempts to return the source code of the callable using 'getsource'.
"""
if keep_callable:
return value
else:
try:
return getsource(value)
except (TypeError, OSError):
return str(value)
class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation
@@ -1040,13 +639,13 @@ class ConfigurableTask(Task):
        if self.config.dataset_name is not None:
            self.DATASET_NAME = self.config.dataset_name
-        self.metric_list: list[MetricConfig] = self.config.get_metrics()
+        self.metric_list: list[MetricConfig] = self.config.get_metrics
        self.download(self.config.dataset_kwargs)
        self._training_docs = None
        self._fewshot_docs = None
-        self._filters = self.config.get_filters()
+        self._filters = self.config.get_filters
        if self.config.use_prompt is not None:
            eval_logger.info(f"loading prompt {self.config.use_prompt}")
@@ -1056,31 +655,11 @@ class ConfigurableTask(Task):
        else:
            self.prompt = None
-        if self.fewshot_docs() is not None:
-            self.fewshot_rnd = (
-                random.Random()
-            )  # setting with no seed, to be overridden at a later time
-            config_sampler: Union[str, Callable] = (
-                self.config.fewshot_config.get("sampler", "default")
-                if self.config.fewshot_config
-                else "default"
-            )
-            if isinstance(config_sampler, str):
-                self.sampler = samplers.get_sampler(config_sampler)(
-                    list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
-                )
-            elif callable(config_sampler) and issubclass(
-                config_sampler, samplers.ContextSampler
-            ):
-                self.sampler = config_sampler(
-                    docs=list(self.fewshot_docs()), task=self, rnd=self.fewshot_rnd
-                )
-            else:
-                raise TypeError(
-                    f"fewshot_config.sampler should be a string or callable of ContextSampler type, "
-                    f"not {type(config_sampler)}"
-                )
+        if self.config.fewshot_cfg.num > 0 and self.fewshot_docs() is not None:
+            self.fewshot_rnd = random.Random()
+            self.sampler = self.config.fewshot_cfg.init_sampler(
+                list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
+            )
        self.task_docs = self.eval_docs
        # Test One Doc
@@ -1203,30 +782,21 @@ class ConfigurableTask(Task):
            return self.dataset[self.config.test_split]

    def fewshot_docs(self):
-        if self.config.fewshot_split is not None:
-            if self.config.process_docs is not None:
-                return self.config.process_docs(self.dataset[self.config.fewshot_split])
-            return self.dataset[self.config.fewshot_split]
-        elif (
-            self.config.fewshot_config is not None
-            and self.config.fewshot_config.get("samples", None) is not None
-        ):
-            if isinstance(self.config.fewshot_config["samples"], list):
-                return self.config.fewshot_config["samples"]
-            elif callable(self.config.fewshot_config["samples"]):
-                return self.config.fewshot_config["samples"]()
-            else:
-                raise Exception(
-                    "`fewshot_config['samples']` was incorrectly defined in the configuration. It should be either a list of samples as a dict, or function returning this list."
-                )
-        else:
-            if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
-                eval_logger.warning(
-                    f"[Task: {self.config.task}] "
-                    "num_fewshot > 0 but fewshot_split is None. "
-                    "using preconfigured rule."
-                )
-            return super().fewshot_docs()
+        docs = self.config.fewshot_cfg.get_docs(self.dataset)
+        if docs is not None:
+            return docs
+        # Fallback to parent implementation
+        if _num_fewshot := getattr(self.config, "num_fewshot"):
+            if isinstance(_num_fewshot, int) and _num_fewshot > 0:
+                eval_logger.warning(
+                    f"[Task: {self.config.task}] "
+                    "num_fewshot > 0 but no fewshot source configured. "
+                    "Using preconfigured rule."
+                )
+        return super().fewshot_docs()

    @staticmethod
    def append_target_question(
@@ -1441,7 +1011,7 @@ class ConfigurableTask(Task):
        """
        return doc

-    def doc_to_text(self, doc: dict, doc_to_text: Optional[int, str, Callable] = None):
+    def doc_to_text(self, doc: dict, doc_to_text: Union[int, str, Callable] = None):
        if self.prompt is not None:
            doc_to_text = self.prompt
        elif doc_to_text is not None:
...

lm_eval/config/metric.py
from dataclasses import dataclass
from functools import cached_property
from typing import Any, Callable, List, Optional
@dataclass
class MetricConfig:
"""Encapsulates information about a single metric."""
name: str
fn: Optional[Callable] = None
kwargs: Optional[dict] = None
aggregation_fn: Optional[Callable] = None
higher_is_better: bool = True
hf_evaluate: bool = False
is_elementwise: bool = True
@cached_property
def metric_name(self) -> str:
return self.name
@cached_property
def aggregation(self) -> Callable:
from lm_eval.api.registry import get_aggregation
if self.aggregation_fn is None:
return get_aggregation(self.name)
return self.aggregation_fn
@cached_property
def _higher_is_better(self) -> bool:
from lm_eval.api.registry import is_higher_better
if self.higher_is_better is None:
return is_higher_better(self.name)
return self.higher_is_better
def compute_metric(self, *args, **kwargs) -> Any:
"""Calculates the metric using the provided function and arguments."""
if self.fn is None:
raise ValueError(f"Metric function for {self.name} is not defined.")
return self.fn(*args, **{**self.kwargs, **kwargs})
def compute_aggregation(self, values: List[Any]) -> Any:
"""Computes the aggregation of the metric values."""
if self.aggregation_fn is None:
raise ValueError(f"Aggregation function for {self.name} is not defined.")
return self.aggregation_fn(values)
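Usage sketch for the new MetricConfig module (imported elsewhere in this commit as lm_eval.config.metric). The `exact_match` function and the averaging lambda below are illustrative and not part of the commit:

# Illustrative sketch: wire a toy elementwise metric into MetricConfig.
from lm_eval.config.metric import MetricConfig

def exact_match(prediction: str, reference: str) -> int:
    # 1 if the two strings match exactly (ignoring surrounding whitespace), else 0
    return int(prediction.strip() == reference.strip())

em = MetricConfig(
    name="exact_match",
    fn=exact_match,
    kwargs={},  # compute_metric merges these with call-time kwargs
    aggregation_fn=lambda scores: sum(scores) / len(scores),
    higher_is_better=True,
)

scores = [em.compute_metric(p, r) for p, r in [("cat", "cat"), ("dog", "cow")]]
print(em.compute_aggregation(scores))  # 0.5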
import logging
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType
from lm_eval.config.metric import MetricConfig
from lm_eval.config.utils import maybe_serialize
if TYPE_CHECKING:
from lm_eval.api.samplers import ContextSampler
from lm_eval.api.task import Task, eval_logger
eval_logger = logging.getLogger(__name__)
@dataclass
class RepeatConfig:
"""Encapsulates information about a single repeat."""
repeats: int = 1
metric_fn: Union[str, Callable] = "pass@N"
kwargs: Optional[dict] = None
@dataclass
class FilterConfig:
"""Encapsulates information about a single filter."""
name: str
fn: Optional[Callable] = None
kwargs: Optional[dict] = None
@dataclass
class FewshotConfig:
num: int = 0
split: Optional[str] = None
sampler: Union[str, Callable] = "default"
samples: Union[Callable[[], list[dict]], list[dict], None] = None
process_docs: Optional[Callable[[list[dict]], Iterable[dict]]] = None
fewshot_indices: Optional[list[int]] = None
rnd: int = field(init=False, default=False)
def __post_init__(self) -> None:
if self.samples is not None and not (
isinstance(self.samples, list) or callable(self.samples)
):
raise TypeError(
"samples must be either list[dict] or callable returning list[dict]"
)
if self.split is not None and self.samples is not None:
eval_logger.warning(
"Both split and samples are configured; split will take precedence"
)
@property
def has_source(self) -> bool:
"""Check if any fewshot source is configured."""
return self.split is not None or self.samples is not None
def _get_raw_docs(
self, dataset
) -> Union[list[dict], Callable[[], Iterable[dict]], None]:
"""Get raw documents from configured source."""
if self.split is not None:
return dataset[self.split]
if self.samples is not None:
if isinstance(self.samples, list):
return self.samples
elif callable(self.samples):
return self.samples
else:
raise TypeError(
"samples must be either a list of dicts or a callable returning a list"
)
def get_docs(self, dataset) -> Optional[Iterable[dict]]:
"""Get processed documents from configured source."""
raw_docs = self._get_raw_docs(dataset)
if raw_docs is None:
return None
if self.process_docs is not None:
return self.process_docs(raw_docs)
return raw_docs
@property
def get_sampler(self):
from lm_eval.api import samplers
if isinstance(self.sampler, str):
return samplers.get_sampler(self.sampler)
elif callable(self.sampler):
return self.sampler
def init_sampler(
self, docs: list[dict], task: "Task", rnd=None, fewshot_indices=None
) -> "ContextSampler":
"""Initialize the sampler with the given documents and task."""
if rnd is None:
raise ValueError(
"A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
)
return self.get_sampler(
docs,
task,
rnd=rnd,
fewshot_indices=fewshot_indices
if fewshot_indices
else self.fewshot_indices,
)
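A short sketch of how FewshotConfig resolves few-shot documents, here from inline `samples` rather than a dataset split. The toy documents are illustrative, and the import assumes FewshotConfig ships next to TaskConfig in lm_eval.config.task:

# Illustrative sketch: inline few-shot samples with an optional process_docs hook.
from lm_eval.config.task import FewshotConfig

fewshot_cfg = FewshotConfig(
    num=2,
    samples=[
        {"question": "2+2?", "answer": "4"},
        {"question": "3+3?", "answer": "6"},
    ],
    process_docs=lambda docs: [d for d in docs if d["answer"]],
)

docs = fewshot_cfg.get_docs(dataset=None)  # `dataset` is ignored when `samples` is set
print(list(docs))
# init_sampler(list(docs), task, rnd=random.Random(1234)) would then build the
# ContextSampler used by ConfigurableTask above.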
@dataclass
class DatasetConfig:
"""Encapsulates information about a dataset."""
path: Optional[str] = None
name: Optional[str] = None
kwargs: Optional[dict] = field(default_factory=dict)
custom: Optional[Callable] = None
metadata: Optional[dict] = None
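DatasetConfig carries the same path/name/kwargs trio that the old dataset_path/dataset_name/dataset_kwargs fields held. A sketch of the typical downstream use; feeding these fields to datasets.load_dataset is an assumption, since the diff does not show the download path:

# Assumed downstream use; not shown in this commit.
from datasets import load_dataset  # Hugging Face `datasets`
from lm_eval.config.task import DatasetConfig

ds_cfg = DatasetConfig(path="boolq", kwargs={})
data = load_dataset(ds_cfg.path, ds_cfg.name, **(ds_cfg.kwargs or {}))
print(data["validation"][0])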
@dataclass
class TaskConfig(dict):
# task naming/registry
task: Optional[str] = None
task_alias: Optional[str] = None
tag: Optional[Union[str, list]] = None
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
custom_dataset: Optional[Callable] = None
dataset_path: Optional[str] = None
dataset_name: Optional[str] = None
dataset_kwargs: Optional[dict] = None
training_split: Optional[str] = None
validation_split: Optional[str] = None
test_split: Optional[str] = None
fewshot_split: Optional[str] = (
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
)
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
process_docs: Optional[Callable] = None
doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None
doc_to_image: Union[Callable, str, None] = None
doc_to_audio: Union[Callable, str, None] = None
unsafe_code: bool = False
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None
use_prompt: Optional[str] = None
description: str = ""
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
fewshot_config: Optional[dict] = None
# runtime configuration options
num_fewshot: Optional[int] = 0
# scoring options
metric_list: Optional[list] = None
output_type: OutputType = "generate_until"
generation_kwargs: Optional[dict] = None
repeats: int = 1
filter_list: Optional[list[dict]] = None
should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None
gen_prefix: Optional[str] = None
metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
)
_metric_list: list[MetricConfig] = None
_filter_list: list[FilterConfig] = None
ds_cfg: DatasetConfig = None
fewshot_cfg: FewshotConfig = None
def __post_init__(self) -> None:
### ---setup generation kwargs--- ###
if self.generation_kwargs is not None:
if self.output_type != "generate_until":
eval_logger.warning(
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
)
if "temperature" in self.generation_kwargs:
self.generation_kwargs["temperature"] = float(
self.generation_kwargs["temperature"]
)
if "until" not in self.generation_kwargs:
eval_logger.warning(
f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
)
self.generation_kwargs["until"] = [self.fewshot_delimiter]
else:
if self.output_type == "generate_until":
# ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {
"until": (
None
if self.fewshot_delimiter is None
else [self.fewshot_delimiter]
),
"do_sample": False,
"temperature": 0,
}
eval_logger.warning(
f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
)
# ---setup dataset config--- #
self.ds_cfg = DatasetConfig(
path=self.dataset_path,
name=self.dataset_name,
kwargs=self.dataset_kwargs,
custom=self.custom_dataset,
metadata=self.metadata,
)
# ---setup fewshot config--- #
_fewshot_cfg = self.fewshot_config if self.fewshot_config is not None else {}
self.fewshot_cfg = FewshotConfig(
split=self.fewshot_split,
sampler=_fewshot_cfg.get("sampler", "default"),
samples=_fewshot_cfg.get("samples", None),
process_docs=_fewshot_cfg.get("process_docs", None),
fewshot_indices=_fewshot_cfg.get("fewshot_indices", None),
)
@property
def get_metrics(self) -> list["MetricConfig"]:
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
get_aggregation,
get_metric,
get_metric_aggregation,
is_higher_better,
)
metrics = []
if self.metric_list is None:
# ---------- 1. If no metrics defined, use defaults for output type ----------
_metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
eval_logger.info(
f"No metrics defined in config, using default metrics for {self.output_type}={_metric_list}"
)
metrics.extend(
MetricConfig(
name=metric_name,
fn=get_metric(metric_name),
aggregation_fn=get_metric_aggregation(metric_name),
higher_is_better=is_higher_better(metric_name),
)
for metric_name in _metric_list
)
else:
# ---------- 2. Process user-defined metrics from config ----------
for metric_config in self.metric_list:
metric_name = metric_config["metric"]
_metric_fn_kwargs = {
key: metric_config[key]
for key in metric_config
if key
not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
}
_hf_evaluate_metric: bool = metric_config.get("hf_evaluate", False)
_metric_fn = None
_aggregation = None
if self.process_results is not None:
# User will compute metrics inside `process_results()`
_metric_name = None
_metric_fn_kwargs = {}
elif callable(metric_name):
# User passed a function object
_metric_name = metric_name.__name__
_metric_fn = metric_name.__call__
else:
# Normal: look up by name
_metric_name = metric_name
_metric_fn = get_metric(metric_name, _hf_evaluate_metric)
# ---------- 3. Decide how to aggregate examples ----------
if "aggregation" in metric_config:
if isinstance(_agg_name := metric_config["aggregation"], str):
_aggregation = get_aggregation(_agg_name)
elif callable(_agg_name): # noqa: E721
_aggregation = metric_config["aggregation"]
else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
_aggregation = get_metric_aggregation(metric_name)
eval_logger.warning(
f"[Task: {self.task}] metric {metric_name} is defined, but aggregation is not. "
f"using default "
f"aggregation={INV_AGG_REGISTRY[_aggregation]}"
)
# ---------- 4. Determine “higher-is-better” semantics ----------
if "higher_is_better" in metric_config:
_higher_is_better = metric_config["higher_is_better"]
else:
eval_logger.warning(
f"[Task: {self.task}] metric {metric_name} is defined, but higher_is_better is not. "
f"using default "
f"higher_is_better={is_higher_better(metric_name)}"
)
_higher_is_better = is_higher_better(metric_name)
metrics.append(
MetricConfig(
name=_metric_name,
fn=_metric_fn,
kwargs=_metric_fn_kwargs,
aggregation_fn=_aggregation,
higher_is_better=_higher_is_better,
hf_evaluate=_hf_evaluate_metric,
)
)
return metrics
@property
def get_filters(self) -> list["FilterEnsemble"]:
from lm_eval.filters import build_filter_ensemble
if not self.filter_list:
eval_logger.debug(
"No custom filters defined; falling back to 'take_first' for handling repeats."
)
return [build_filter_ensemble("none", [["take_first", None]])]
else:
def _strip_fn(d: dict) -> dict:
return {k: v for k, v in d.items() if k != "function"}
configs = (
self.filter_list.values()
if isinstance(self.filter_list, dict)
else self.filter_list
)
return [
build_filter_ensemble(
filter_name=cfg["name"],
components=[[_strip_fn(f) for f in cfg["filter"]]],
)
for cfg in configs
]
def __getitem__(self, item):
return getattr(self, item)
def __setitem__(self, item, value):
return setattr(self, item, value)
def to_dict(self, keep_callable: bool = False) -> dict:
"""Return a printable dict with Nones stripped and callables serialised.
:return: dict
A printable dictionary version of the TaskConfig object.
"""
cfg = asdict(self)
return {
k: [
{mk: maybe_serialize(mv, keep_callable) for mk, mv in md.items()}
for md in v
]
if k == "metric_list"
else maybe_serialize(v)
for k, v in cfg.items()
if v is not None
}
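End-to-end sketch of the refactored TaskConfig: `__post_init__` now fills default generation_kwargs and builds the nested ds_cfg / fewshot_cfg, while get_metrics and get_filters are properties. The task name, dataset and metric kwargs below are illustrative, and assume "exact_match" is registered in the metric registry:

# Illustrative sketch of constructing a generate_until task config.
from lm_eval.config.task import TaskConfig

cfg = TaskConfig(
    task="toy_qa",
    dataset_path="json",
    dataset_kwargs={"data_files": "toy_qa.jsonl"},
    test_split="train",
    doc_to_text="Q: {{question}}\nA:",
    doc_to_target="{{answer}}",
    output_type="generate_until",
    metric_list=[{"metric": "exact_match", "ignore_case": True}],
)

print(cfg.generation_kwargs)  # {'until': ['\n\n'], 'do_sample': False, 'temperature': 0}
print(cfg.ds_cfg.path, cfg.fewshot_cfg.num)  # 'json' 0
metrics = cfg.get_metrics   # property: resolves names/aggregations via the registry
filters = cfg.get_filters   # property: defaults to a single 'take_first' ensemble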
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Optional, Union
if TYPE_CHECKING:
from lm_eval.config.metric import MetricConfig
@dataclass
class TemplateConfig:
"""Encapsulates information about a template."""
template: str
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
description: str
context_prefix: str
prefix_delimiter: str
context_delimiter: str
answer_suffix: str
target_delimiter: str
choice_format: Optional[str]
choice_delimiter: Optional[str]
fewshot_delimiter: str
metric_list: Optional[Union[list[str], list["MetricConfig"]]] = field(
default_factory=lambda: ["acc", "acc_norm"]
)
@dataclass
class MCQTemplateConfig:
"""Encapsulates information about a template.
Would return a sample with the following format:
Question: <doc_to_text(doc)>
A. <doc_to_choice(doc)[0]>
B. <doc_to_choice(doc)[1]>
C. <doc_to_choice(doc)[2]>
D. <doc_to_choice(doc)[3]>
Answer:` doc_to_choice(doc)` for each choice.
"""
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
template = "mcq"
context_prefix: str = "Question:"
prefix_delimiter: str = " "
context_delimiter: str = "\n"
answer_suffix: str = "Answer:"
target_delimiter: str = "\n"
choice_format: Optional[str] = "letters"
choice_delimiter: Optional[str] = "\n"
fewshot_delimiter: str = "\n\n"
metric_list: Optional[list["MetricConfig"]] = field(default_factory=lambda: ["acc"])
@dataclass
class ClozeTemplateConfig:
"""Encapsulates information about a template.
Would return a sample with the following format:
Question: <doc_to_text(doc)>
Answer:` <doc_to_target(doc)>`
"""
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
template: str = "cloze"
description: str = ""
context_prefix: str = "Question:"
prefix_delimiter: str = " "
context_delimiter: str = "\n"
answer_suffix: str = "Answer:"
target_delimiter: str = " "
choice_format: Optional[str] = None
choice_delimiter: Optional[str] = None
fewshot_delimiter: str = "\n\n"
metric_list: Optional[list["MetricConfig"]] = field(
default_factory=lambda: ["acc", "acc_norm"]
)
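To make the delimiter fields concrete, a hypothetical rendering of one MCQ document using the defaults above. Both render_mcq and the lm_eval.config.template module path are assumptions; the commit does not show how templates are consumed:

# Hypothetical rendering helper; module path is assumed.
from string import ascii_uppercase
from lm_eval.config.template import MCQTemplateConfig

def render_mcq(doc: dict, t: MCQTemplateConfig) -> str:
    lines = [f"{t.context_prefix}{t.prefix_delimiter}{doc['question']}"]
    lines += [f"{ascii_uppercase[i]}. {c}" for i, c in enumerate(doc["choices"])]
    lines.append(t.answer_suffix)
    return t.context_delimiter.join(lines)

t = MCQTemplateConfig(
    doc_to_text="question", doc_to_choice="choices", doc_to_target="answer"
)
print(render_mcq({"question": "2+2?", "choices": ["3", "4"], "answer": 1}, t))
# Question: 2+2?
# A. 3
# B. 4
# Answer: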
from inspect import getsource
from typing import Any, Callable, Union
def serialize_callable(
value: Union[Callable, str], keep_callable=False
) -> Union[Callable, str]:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
Otherwise, attempts to return the source code of the callable using 'getsource'.
If serialization fails, returns the string representation.
"""
if keep_callable:
return value
else:
try:
return getsource(value)
except (TypeError, OSError):
return str(value)
def maybe_serialize(
val: Union[Callable, Any], keep_callable=False
) -> Union[Callable, Any]:
"""Conditionally serializes a value if it is callable."""
return (
serialize_callable(val, keep_callable=keep_callable) if callable(val) else val
)
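Usage sketch for the serialization helpers; the doc_to_text function below is illustrative:

# Illustrative sketch: callables are serialised to their source text for printable
# configs (falling back to repr), other values pass through unchanged.
from lm_eval.config.utils import maybe_serialize

def doc_to_text(doc: dict) -> str:
    return f"Q: {doc['question']}\nA:"

print(maybe_serialize("Q: {{question}}\nA:"))            # string passes through
print(maybe_serialize(doc_to_text))                      # source via inspect.getsource
print(maybe_serialize(doc_to_text, keep_callable=True))  # the function object itself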
import unittest.mock as mock
from lm_eval.api.metrics import _bootstrap_internal_no_mp, mean
-from lm_eval.api.task import ConfigurableTask, TaskConfig
+from lm_eval.api.task import ConfigurableTask
+from lm_eval.config.task import TaskConfig

class MockConfigurableTask(ConfigurableTask):
...