Commit fb63ac0f authored by Baber

refactor configs to files

parent fb72009f
@@ -257,7 +257,7 @@ class CachingLM:
            eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
            return lm_attr

-        def fn(requests: list[Instance]) -> list[Instance]:
+        def fn(requests: list["Instance"]) -> list["Instance"]:
            res = []
            remaining_reqs = []
            warned = False
@@ -343,7 +343,7 @@ class TemplateLM(LM):

    @abc.abstractmethod
    def _loglikelihood_tokens(
-        self, requests: list[Instance], **kwargs
+        self, requests: list["Instance"], **kwargs
    ) -> list[tuple[float, bool]]:
        pass
@@ -375,7 +375,7 @@ class TemplateLM(LM):
        return context_enc, continuation_enc

    def loglikelihood(
-        self, requests: list[Instance], disable_tqdm: bool = False
+        self, requests: list["Instance"], disable_tqdm: bool = False
    ) -> list[tuple[float, bool]]:
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
...
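# Illustrative sketch (not part of this commit): the quoted annotations above are
# forward references. Combined with an import guarded by typing.TYPE_CHECKING they
# keep the hints checkable without importing lm_eval.api.instance at runtime.
# The actual import placement in these files is not shown in this diff.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from lm_eval.api.instance import Instance


def passthrough(requests: list["Instance"]) -> list["Instance"]:
    # "Instance" is resolved only by static type checkers, never at runtime
    return requests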
lm_eval/config/metric.py
from dataclasses import dataclass
from functools import cached_property
from typing import Any, Callable, List, Optional
from lm_eval.api.registry import get_aggregation, is_higher_better
@dataclass
class MetricConfig:
"""Encapsulates information about a single metric."""
name: str
fn: Optional[Callable] = None
kwargs: Optional[dict] = None
aggregation_fn: Optional[Callable] = None
    higher_is_better: Optional[bool] = True
hf_evaluate: bool = False
is_elementwise: bool = True
@cached_property
def metric_name(self) -> str:
return self.name
@cached_property
def aggregation(self) -> Callable:
if self.aggregation_fn is None:
return get_aggregation(self.name)
return self.aggregation_fn
@cached_property
def _higher_is_better(self) -> bool:
if self.higher_is_better is None:
return is_higher_better(self.name)
return self.higher_is_better
def compute_metric(self, *args, **kwargs) -> Any:
"""Calculates the metric using the provided function and arguments."""
if self.fn is None:
raise ValueError(f"Metric function for {self.name} is not defined.")
        return self.fn(*args, **{**(self.kwargs or {}), **kwargs})
def compute_aggregation(self, values: List[Any]) -> Any:
"""Computes the aggregation of the metric values."""
if self.aggregation_fn is None:
raise ValueError(f"Aggregation function for {self.name} is not defined.")
return self.aggregation_fn(values)
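# Usage sketch for the MetricConfig above (illustrative, not part of this commit).
# The exact_match function and the mean lambda are stand-ins, not the registry
# implementations that get_metric/get_aggregation would return.
def exact_match(prediction: str, reference: str) -> float:
    return float(prediction.strip() == reference.strip())


em = MetricConfig(
    name="exact_match",
    fn=exact_match,
    kwargs={},
    aggregation_fn=lambda values: sum(values) / len(values),
    higher_is_better=True,
)
per_sample = [em.compute_metric("4", "4"), em.compute_metric("5", "4")]
print(em.compute_aggregation(per_sample))  # 0.5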
lm_eval/config/task.py
import logging
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
get_aggregation,
get_metric,
get_metric_aggregation,
is_higher_better,
)
from lm_eval.config.metric import MetricConfig
from lm_eval.config.utils import maybe_serialize
from lm_eval.filters import build_filter_ensemble
if TYPE_CHECKING:
from lm_eval.api.samplers import ContextSampler
    from lm_eval.api.task import Task
eval_logger = logging.getLogger(__name__)
@dataclass
class RepeatConfig:
"""Encapsulates information about a single repeat."""
repeats: int = 1
metric_fn: Union[str, Callable] = "pass@N"
kwargs: Optional[dict] = None
@dataclass
class FilterConfig:
"""Encapsulates information about a single filter."""
name: str
fn: Optional[Callable] = None
kwargs: Optional[dict] = None
@dataclass
class FewshotConfig:
num: int = 0
split: Optional[str] = None
sampler: Union[str, Callable] = "default"
samples: Union[Callable[[], list[dict]], list[dict], None] = None
process_docs: Optional[Callable[[list[dict]], Iterable[dict]]] = None
fewshot_indices: Optional[list[int]] = None
    rnd: bool = field(init=False, default=False)
def __post_init__(self) -> None:
if self.samples is not None and not (
isinstance(self.samples, list) or callable(self.samples)
):
raise TypeError(
"samples must be either list[dict] or callable returning list[dict]"
)
if self.split is not None and self.samples is not None:
eval_logger.warning(
"Both split and samples are configured; split will take precedence"
)
@property
def has_source(self) -> bool:
"""Check if any fewshot source is configured."""
return self.split is not None or self.samples is not None
def _get_raw_docs(
self, dataset
) -> Union[list[dict], Callable[[], Iterable[dict]], None]:
"""Get raw documents from configured source."""
if self.split is not None:
return dataset[self.split]
if self.samples is not None:
if isinstance(self.samples, list):
return self.samples
elif callable(self.samples):
return self.samples
else:
raise TypeError(
"samples must be either a list of dicts or a callable returning a list"
)
def get_docs(self, dataset) -> Optional[Iterable[dict]]:
"""Get processed documents from configured source."""
raw_docs = self._get_raw_docs(dataset)
if raw_docs is None:
return None
if self.process_docs is not None:
return self.process_docs(raw_docs)
return raw_docs
@property
def get_sampler(self):
from lm_eval.api import samplers
if isinstance(self.sampler, str):
return samplers.get_sampler(self.sampler)
elif callable(self.sampler):
return self.sampler
def init_sampler(
self, docs: list[dict], task: "Task", rnd=None, fewshot_indices=None
) -> "ContextSampler":
"""Initialize the sampler with the given documents and task."""
if rnd is None:
raise ValueError(
"A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
)
return self.get_sampler(
docs,
task,
rnd=rnd,
fewshot_indices=fewshot_indices
if fewshot_indices
else self.fewshot_indices,
)
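# Usage sketch for FewshotConfig (illustrative, not part of this commit).
# A plain dict stands in for the dataset object; real tasks pass the splits
# loaded via the dataset config.
few = FewshotConfig(
    num=2,
    split="validation",
    process_docs=lambda docs: [d for d in docs if d.get("label") is not None],
)
toy_dataset = {
    "validation": [
        {"text": "q1", "label": 0},
        {"text": "q2", "label": None},
    ]
}
if few.has_source:
    fewshot_docs = few.get_docs(toy_dataset)  # -> [{"text": "q1", "label": 0}]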
@dataclass
class DatasetConfig:
"""Encapsulates information about a dataset."""
path: Optional[str] = None
name: Optional[str] = None
kwargs: Optional[dict] = field(default_factory=dict)
custom: Optional[Callable] = None
metadata: Optional[dict] = None
@dataclass
class TaskConfig(dict):
# task naming/registry
task: Optional[str] = None
task_alias: Optional[str] = None
tag: Optional[Union[str, list]] = None
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
custom_dataset: Optional[Callable] = None
dataset_path: Optional[str] = None
dataset_name: Optional[str] = None
dataset_kwargs: Optional[dict] = None
training_split: Optional[str] = None
validation_split: Optional[str] = None
test_split: Optional[str] = None
fewshot_split: Optional[str] = (
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
)
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
process_docs: Optional[Callable] = None
doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None
doc_to_image: Union[Callable, str, None] = None
doc_to_audio: Union[Callable, str, None] = None
unsafe_code: bool = False
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None
use_prompt: Optional[str] = None
description: str = ""
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
fewshot_config: Optional[dict] = None
# runtime configuration options
num_fewshot: Optional[int] = 0
# scoring options
metric_list: Optional[list] = None
output_type: OutputType = "generate_until"
generation_kwargs: Optional[dict] = None
repeats: int = 1
filter_list: Optional[list[dict]] = None
should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None
gen_prefix: Optional[str] = None
metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
)
    _metric_list: Optional[list[MetricConfig]] = None
    _filter_list: Optional[list[FilterConfig]] = None
    ds_cfg: Optional[DatasetConfig] = None
    fewshot_cfg: Optional[FewshotConfig] = None
def __post_init__(self) -> None:
### ---setup generation kwargs--- ###
if self.generation_kwargs is not None:
if self.output_type != "generate_until":
eval_logger.warning(
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
)
if "temperature" in self.generation_kwargs:
self.generation_kwargs["temperature"] = float(
self.generation_kwargs["temperature"]
)
if "until" not in self.generation_kwargs:
eval_logger.warning(
f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
)
self.generation_kwargs["until"] = [self.fewshot_delimiter]
else:
if self.output_type == "generate_until":
# ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {
"until": (
None
if self.fewshot_delimiter is None
else [self.fewshot_delimiter]
),
"do_sample": False,
"temperature": 0,
}
eval_logger.warning(
f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
)
# ---setup dataset config--- #
self.ds_cfg = DatasetConfig(
path=self.dataset_path,
name=self.dataset_name,
kwargs=self.dataset_kwargs,
custom=self.custom_dataset,
metadata=self.metadata,
)
# ---setup fewshot config--- #
_fewshot_cfg = self.fewshot_config if self.fewshot_config is not None else {}
self.fewshot_cfg = FewshotConfig(
split=self.fewshot_split,
sampler=_fewshot_cfg.get("sampler", "default"),
samples=_fewshot_cfg.get("samples", None),
process_docs=_fewshot_cfg.get("process_docs", None),
fewshot_indices=_fewshot_cfg.get("fewshot_indices", None),
)
@property
def get_metrics(self) -> list["MetricConfig"]:
metrics = []
if self.metric_list is None:
# ---------- 1. If no metrics defined, use defaults for output type ----------
_metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
eval_logger.info(
f"No metrics defined in config, using default metrics for {self.output_type}={_metric_list}"
)
metrics.extend(
MetricConfig(
name=metric_name,
fn=get_metric(metric_name),
aggregation_fn=get_metric_aggregation(metric_name),
higher_is_better=is_higher_better(metric_name),
)
for metric_name in _metric_list
)
else:
# ---------- 2. How will the samples be evaluated ----------
for metric_config in self.metric_list:
metric_name = metric_config["metric"]
_metric_fn_kwargs = {
key: metric_config[key]
for key in metric_config
if key
not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
}
_hf_evaluate_metric: bool = metric_config.get("hf_evaluate", False)
_metric_fn = None
_aggregation = None
if self.process_results is not None:
# User will compute metrics inside `process_results()`
_metric_name = None
_metric_fn_kwargs = {}
elif callable(metric_name):
# User passed a function object
_metric_name = metric_name.__name__
_metric_fn = metric_name.__call__
                else:
                    # Normal case: look up the metric callable in the registry by name
                    _metric_name = metric_name
                    _metric_fn = get_metric(metric_name, _hf_evaluate_metric)
# ---------- 3. Decide how to aggregate examples ----------
if "aggregation" in metric_config:
if isinstance(_agg_name := metric_config["aggregation"], str):
_aggregation = get_aggregation(_agg_name)
elif callable(_agg_name): # noqa: E721
_aggregation = metric_config["aggregation"]
else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
_aggregation = get_metric_aggregation(metric_name)
eval_logger.warning(
f"[Task: {self.task}] metric {metric_name} is defined, but aggregation is not. "
f"using default "
f"aggregation={INV_AGG_REGISTRY[_aggregation]}"
)
# ---------- 4. Determine “higher-is-better” semantics ----------
if "higher_is_better" in metric_config:
_higher_is_better = metric_config["higher_is_better"]
else:
eval_logger.warning(
f"[Task: {self.task}] metric {metric_name} is defined, but higher_is_better is not. "
f"using default "
f"higher_is_better={is_higher_better(metric_name)}"
)
_higher_is_better = is_higher_better(metric_name)
metrics.append(
MetricConfig(
name=_metric_name,
fn=_metric_fn,
kwargs=_metric_fn_kwargs,
aggregation_fn=_aggregation,
higher_is_better=_higher_is_better,
hf_evaluate=_hf_evaluate_metric,
)
)
return metrics
@property
def get_filters(self) -> list["FilterEnsemble"]:
if not self.filter_list:
eval_logger.debug(
"No custom filters defined; falling back to 'take_first' for handling repeats."
)
return [build_filter_ensemble("none", [["take_first", None]])]
else:
def _strip_fn(d: dict) -> dict:
return {k: v for k, v in d.items() if k != "function"}
configs = (
self.filter_list.values()
if isinstance(self.filter_list, dict)
else self.filter_list
)
return [
build_filter_ensemble(
filter_name=cfg["name"],
components=[[_strip_fn(f) for f in cfg["filter"]]],
)
for cfg in configs
]
def __getitem__(self, item):
return getattr(self, item)
def __setitem__(self, item, value):
return setattr(self, item, value)
def to_dict(self, keep_callable: bool = False) -> dict:
"""Return a printable dict with Nones stripped and callables serialised.
:return: dict
A printable dictionary version of the TaskConfig object.
"""
cfg = asdict(self)
return {
k: [
{mk: maybe_serialize(mv, keep_callable) for mk, mv in md.items()}
for md in v
]
if k == "metric_list"
            else maybe_serialize(v, keep_callable)
for k, v in cfg.items()
if v is not None
}
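# Usage sketch for the TaskConfig above (illustrative, not part of this commit).
# Field names mirror the YAML task configs; metric and filter resolution goes
# through the lm_eval registries imported at the top of this module. The listing
# that follows this sketch is a separate new module holding the prompt-template
# config dataclasses.
cfg = TaskConfig(
    task="toy_copa",
    dataset_path="super_glue",
    dataset_name="copa",
    test_split="validation",
    output_type="multiple_choice",
    metric_list=[{"metric": "acc", "aggregation": "mean", "higher_is_better": True}],
)
print(cfg.ds_cfg.path, cfg.ds_cfg.name)  # super_glue copa
for metric in cfg.get_metrics:           # resolved MetricConfig objects
    print(metric.name, metric.higher_is_better)
filters = cfg.get_filters                # defaults to a single take_first ensemble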
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Optional, Union
if TYPE_CHECKING:
from lm_eval.config.metric import MetricConfig
@dataclass
class TemplateConfig:
"""Encapsulates information about a template."""
template: str
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
description: str
context_prefix: str
prefix_delimiter: str
context_delimiter: str
answer_suffix: str
target_delimiter: str
choice_format: Optional[str]
choice_delimiter: Optional[str]
fewshot_delimiter: str
metric_list: Optional[Union[list[str], list["MetricConfig"]]] = field(
default_factory=lambda: ["acc", "acc_norm"]
)
@dataclass
class MCQTemplateConfig:
"""Encapsulates information about a template.
Would return a sample with the following format:
Question: <doc_to_text(doc)>
A. <doc_to_choice(doc)[0]>
B. <doc_to_choice(doc)[1]>
C. <doc_to_choice(doc)[2]>
D. <doc_to_choice(doc)[3]>
    Answer: `doc_to_choice(doc)[i]` for each choice.
"""
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
template = "mcq"
context_prefix: str = "Question:"
prefix_delimiter: str = " "
context_delimiter: str = "\n"
answer_suffix: str = "Answer:"
target_delimiter: str = "\n"
choice_format: Optional[str] = "letters"
choice_delimiter: Optional[str] = "\n"
fewshot_delimiter: str = "\n\n"
    metric_list: Optional[Union[list[str], list["MetricConfig"]]] = field(
        default_factory=lambda: ["acc"]
    )
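# Rendering sketch for MCQTemplateConfig (illustrative, not part of this commit).
# The harness assembles prompts from these fields elsewhere; this just spells out
# the format described in the docstring above.
mcq = MCQTemplateConfig(
    doc_to_text="question",
    doc_to_choice="choices",
    doc_to_target="answer",
)
doc = {"question": "What is 2 + 2?", "choices": ["3", "4", "5", "6"], "answer": 1}
lines = [f"{mcq.context_prefix}{mcq.prefix_delimiter}{doc['question']}"]
lines += [f"{chr(ord('A') + i)}. {choice}" for i, choice in enumerate(doc["choices"])]
lines.append(mcq.answer_suffix)
print(mcq.context_delimiter.join(lines))
# Question: What is 2 + 2?
# A. 3
# B. 4
# C. 5
# D. 6
# Answer: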
@dataclass
class ClozeTemplateConfig:
"""Encapsulates information about a template.
Would return a sample with the following format:
Question: <doc_to_text(doc)>
    Answer: `<doc_to_target(doc)>`
"""
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
template: str = "cloze"
description: str = ""
context_prefix: str = "Question:"
prefix_delimiter: str = " "
context_delimiter: str = "\n"
answer_suffix: str = "Answer:"
target_delimiter: str = " "
choice_format: Optional[str] = None
choice_delimiter: Optional[str] = None
fewshot_delimiter: str = "\n\n"
    metric_list: Optional[Union[list[str], list["MetricConfig"]]] = field(
        default_factory=lambda: ["acc", "acc_norm"]
    )
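# Rendering sketch for ClozeTemplateConfig (illustrative, not part of this commit),
# matching the format in the docstring above.
cloze = ClozeTemplateConfig(
    doc_to_text="question",
    doc_to_choice="choices",
    doc_to_target="answer",
)
doc = {"question": "The capital of France is", "answer": "Paris"}
context = f"{cloze.context_prefix}{cloze.prefix_delimiter}{doc['question']}"
prompt = f"{context}{cloze.context_delimiter}{cloze.answer_suffix}"
target = f"{cloze.target_delimiter}{doc['answer']}"
print(prompt + target)
# Question: The capital of France is
# Answer: Paris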
lm_eval/config/utils.py
from inspect import getsource
from typing import Any, Callable, Union
def serialize_callable(
value: Union[Callable, str], keep_callable=False
) -> Union[Callable, str]:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
Otherwise, attempts to return the source code of the callable using 'getsource'.
If serialization fails, returns the string representation.
"""
if keep_callable:
return value
else:
try:
return getsource(value)
except (TypeError, OSError):
return str(value)
def maybe_serialize(
val: Union[Callable, Any], keep_callable=False
) -> Union[Callable, Any]:
"""Conditionally serializes a value if it is callable."""
return (
serialize_callable(val, keep_callable=keep_callable) if callable(val) else val
)
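# Usage sketch for maybe_serialize (illustrative, not part of this commit).
# The hunk below this sketch updates a test's imports so TaskConfig now comes
# from lm_eval.config.task instead of lm_eval.api.task.
def doc_to_text(doc: dict) -> str:
    return doc["question"]


print(maybe_serialize(doc_to_text))                      # source text via getsource
print(maybe_serialize(doc_to_text, keep_callable=True))  # the callable itself
print(maybe_serialize("{{question}}"))                   # non-callables pass through unchanged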
-from lm_eval.api.task import ConfigurableTask, TaskConfig
+from lm_eval.api.task import ConfigurableTask
+from lm_eval.config.task import TaskConfig


class MockConfigurableTask(ConfigurableTask):
...