Commit 1b5c6f88 authored by Baber

add MetricConfig

parent 6b3f3f7e
import abc
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import Any, Callable, List, Optional, Union
......@@ -84,7 +83,7 @@ class GroupConfig(dict):
return str(value)
class ConfigurableGroup(abc.ABC):
class ConfigurableGroup:
def __init__(
self,
config: Optional[dict] = None,
......
......@@ -14,10 +14,23 @@ class Instance:
arguments: tuple
idx: int
metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
default_factory=lambda: (None, None, None)
default_factory=lambda: (None, None, None),
metadata=dict(
description="Metadata tuple containing task name, document ID, and number of repeats."
),
)
resps: list = field(
default_factory=list,
metadata=dict(
description="List of responses from the model for this instance."
),
)
filtered_resps: dict = field(
default_factory=dict,
metadata=dict(
description="List of filtered responses for this instance, keyed by filter name."
),
)
resps: list = field(default_factory=list)
filtered_resps: dict = field(default_factory=dict)
# initialized after init
task_name: Optional[str] = None
......@@ -29,7 +42,7 @@ class Instance:
self.task_name, self.doc_id, self.repeats = self.metadata
@property
def args(self):
def args(self) -> tuple:
"""
Returns (string,) where `string` is the string to calculate loglikelihood over
"""
......
......@@ -7,7 +7,6 @@ from collections.abc import Iterable
from typing import List
import numpy as np
import sacrebleu
from lm_eval.api.registry import register_aggregation, register_metric
......@@ -89,6 +88,8 @@ def bleu(items):
Higher is better
"""
import sacrebleu
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
......@@ -104,6 +105,8 @@ def chrf(items):
Higher is better # TODO I think
"""
import sacrebleu
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
......@@ -120,6 +123,8 @@ def ter(items):
Lower is better
"""
import sacrebleu
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
......
import logging
from typing import Callable, Dict, Union
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
import evaluate as hf_evaluate
from lm_eval.api.model import LM
if TYPE_CHECKING:
from lm_eval.api.model import LM
eval_logger = logging.getLogger(__name__)
......@@ -12,6 +11,8 @@ MODEL_REGISTRY = {}
def register_model(*names):
from lm_eval.api.model import LM
# either pass a list or a single alias.
# function receives them as a tuple of strings
......@@ -31,7 +32,7 @@ def register_model(*names):
return decorate
def get_model(model_name):
def get_model(model_name: str) -> type["LM"]:
try:
return MODEL_REGISTRY[model_name]
except KeyError:
......@@ -46,7 +47,7 @@ ALL_TASKS = set()
func2task_index = {}
def register_task(name):
def register_task(name: str):
def decorate(fn):
assert name not in TASK_REGISTRY, (
f"task named '{name}' conflicts with existing registered task!"
......@@ -120,7 +121,7 @@ def register_metric(**args):
return decorate
def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
def get_metric(name: str, hf_evaluate_metric=False) -> Optional[Callable]:
if not hf_evaluate_metric:
if name in METRIC_REGISTRY:
return METRIC_REGISTRY[name]
......@@ -130,6 +131,8 @@ def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
)
try:
import evaluate as hf_evaluate
metric_object = hf_evaluate.load(name)
return metric_object.compute
except Exception:
......@@ -150,21 +153,21 @@ def register_aggregation(name: str):
return decorate
def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
def get_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(f"{name} not a registered aggregation metric!")
def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
try:
return METRIC_AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
def is_higher_better(metric_name) -> bool:
def is_higher_better(metric_name) -> Optional[bool]:
try:
return HIGHER_IS_BETTER_REGISTRY[metric_name]
except KeyError:
......
import logging
import warnings
from functools import partial
from typing import TYPE_CHECKING, Iterable, Optional, Union
from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union
import datasets
......@@ -181,7 +181,7 @@ class ContextSampler:
return chat_history
def sample(self, n: int):
def sample(self, n: int) -> Sequence[dict]:
"""
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
"""
......@@ -190,7 +190,7 @@ class ContextSampler:
class FirstNSampler(ContextSampler):
def sample(self, n: int) -> None:
def sample(self, n: int) -> Sequence[dict]:
"""
Draw the first `n` samples in order from the specified split.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
......
......@@ -6,6 +6,7 @@ import re
from collections.abc import Callable
from copy import deepcopy
from dataclasses import asdict, dataclass
from functools import cached_property
from inspect import getsource
from typing import (
Any,
......@@ -23,6 +24,7 @@ from typing import (
import datasets
import numpy as np
from tqdm import tqdm
from typing_extensions import deprecated
from lm_eval import utils
from lm_eval.api import samplers
......@@ -51,6 +53,43 @@ ALL_OUTPUT_TYPES = [
eval_logger = logging.getLogger(__name__)
@dataclass
class MetricConfig:
"""Encapsulates information about a single metric."""
name: str
fn: Optional[Callable] = None
kwargs: Optional[dict] = None
aggregation_fn: Optional[Callable] = None
higher_is_better: bool = True
hf_evaluate: bool = False
@cached_property
def metric_names(self) -> str:
return self.name
@cached_property
def aggregation(self) -> Callable:
if self.aggregation_fn is None:
return get_aggregation(self.name)
return self.aggregation_fn
@cached_property
def _higher_is_better(self) -> bool:
if self.higher_is_better is None:
return is_higher_better(self.name)
return self.higher_is_better
@dataclass
class FilterConfig:
"""Encapsulates information about a filter."""
name: str
fn: Optional[Callable] = None
kwargs: Optional[dict] = None
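For illustration, a fully specified MetricConfig from this commit could be built directly. This is a minimal sketch assuming the commit is applied and MetricConfig is importable from lm_eval.api.task; exact_match_fn and statistics.mean are stand-ins for registered metric and aggregation functions, not registry lookups:

from statistics import mean

from lm_eval.api.task import MetricConfig  # assumes this commit is applied

def exact_match_fn(references, predictions, **kwargs):
    # Toy stand-in for a registered metric: 1.0 when prediction matches reference.
    return float(references == predictions)

cfg = MetricConfig(
    name="exact_match",
    fn=exact_match_fn,
    kwargs={},            # extra keyword arguments forwarded to `fn`
    aggregation_fn=mean,  # combines per-example scores into a single number
    higher_is_better=True,
    hf_evaluate=False,
)
assert cfg.metric_names == "exact_match"  # cached_property defined above
assert cfg.aggregation is mean            # registry lookup happens only when aggregation_fn is None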
@dataclass
class TaskConfig(dict):
# task naming/registry
......@@ -133,6 +172,93 @@ class TaskConfig(dict):
f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
)
if self.metric_list is not None:
for metric_config in self.metric_list:
if "metric" not in metric_config:
raise ValueError(
"'metric' key not provided for an entry in 'metric_list', must be specified!"
)
def get_metrics(self) -> list["MetricConfig"]:
metrics = []
if self.metric_list is None:
_metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
metrics.extend(
MetricConfig(
name=metric_name,
fn=get_metric(metric_name),
aggregation_fn=get_metric_aggregation(metric_name),
higher_is_better=is_higher_better(metric_name),
)
for metric_name in _metric_list
)
else:
for metric_config in self.metric_list:
if "metric" not in metric_config:
raise ValueError(
"'metric' key not provided for an entry in 'metric_list', must be specified!"
)
metric_name = metric_config["metric"]
_metric_fn_kwargs = {
key: metric_config[key]
for key in metric_config
if key
not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
}
_hf_evaluate_metric: bool = metric_config.get("hf_evaluate", False)
_metric_fn = None
_aggregation = None
if self.process_results is not None:
# User will compute metrics inside `process_results()`
_metric_name = None
_metric_fn_kwargs = {}
elif callable(metric_name):
# User passed a function object
_metric_name = metric_name.__name__
_metric_fn = metric_name.__call__
else:
# Normal: look up by name
_metric_name = get_metric(metric_name, _hf_evaluate_metric)
# ---------- 3. Decide how to aggregate examples ----------
if "aggregation" in metric_config:
if isinstance(_agg_name := metric_config["aggregation"], str):
_aggregation = get_aggregation(_agg_name)
elif callable(_agg_name): # noqa: E721
_aggregation = metric_config["aggregation"]
else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
_aggregation = get_metric_aggregation(metric_name)
eval_logger.warning(
f"[Task: {self.task}] metric {metric_name} is defined, but aggregation is not. "
f"using default "
f"aggregation={INV_AGG_REGISTRY[_aggregation]}"
)
# ---------- 4. Determine “higher-is-better” semantics ----------
if "higher_is_better" in metric_config:
_higher_is_better = metric_config["higher_is_better"]
else:
eval_logger.warning(
f"[Task: {self.task}] metric {metric_name} is defined, but higher_is_better is not. "
f"using default "
f"higher_is_better={is_higher_better(metric_name)}"
)
_higher_is_better = is_higher_better(metric_name)
metrics.append(
MetricConfig(
name=_metric_name,
fn=_metric_fn,
kwargs=_metric_fn_kwargs,
aggregation_fn=_aggregation,
higher_is_better=_higher_is_better,
hf_evaluate=_hf_evaluate_metric,
)
)
return metrics
def __getitem__(self, item):
return getattr(self, item)
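To make the mapping performed by get_metrics() concrete, here is a sketch of how a single metric_list entry is split into a metric name, forwarded kwargs, and the reserved configuration keys; the entry and its "lowercase" key are hypothetical:

# Hypothetical task-config entry; only "metric" is required.
metric_entry = {
    "metric": "bleu",
    "aggregation": "mean",
    "higher_is_better": True,
    "hf_evaluate": False,
    "lowercase": True,  # any non-reserved key is passed through to the metric fn
}

RESERVED_KEYS = ("metric", "aggregation", "higher_is_better", "hf_evaluate")
metric_fn_kwargs = {k: v for k, v in metric_entry.items() if k not in RESERVED_KEYS}
assert metric_fn_kwargs == {"lowercase": True}

# get_metrics() then wraps these pieces into MetricConfig(name="bleu",
# kwargs=metric_fn_kwargs, ...), resolving the callable, aggregation, and
# higher_is_better from the registries whenever they are not given explicitly.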
......@@ -534,7 +660,7 @@ class Task(abc.ABC):
"""
pass
@abc.abstractmethod
@deprecated("not used anymore")
def aggregation(self):
"""
:returns: {str: [metric_score] -> float}
......@@ -543,7 +669,7 @@ class Task(abc.ABC):
"""
pass
@abc.abstractmethod
@deprecated("not used anymore")
def higher_is_better(self):
"""
:returns: {str: bool}
......@@ -661,23 +787,13 @@ class Task(abc.ABC):
Parameters:
- metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
"""
(
self._metric_fn_list,
self._aggregation_list,
self._metric_fn_kwargs,
self._higher_is_better,
) = ({}, {}, {}, {})
self._metric_fn_list[metric_name] = get_metric(metric_name)
self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
self._metric_fn_kwargs[metric_name] = {}
if not isinstance(self, ConfigurableTask):
self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
self.aggregation = lambda: {
metric_name: get_metric_aggregation(metric_name)
}
setattr(self._config, "metric_list", [{"metric": metric_name}])
setattr(self._config, "process_results", None)
# if not isinstance(self, ConfigurableTask):
# self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
# self.aggregation = lambda: {
# metric_name: get_metric_aggregation(metric_name)
# }
setattr(self._config, "metric_list", [MetricConfig(name=metric_name)])
setattr(self._config, "process_results", lambda *args: {"bypass": 0})
def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
self.fewshot_rnd = random.Random(seed)
......@@ -739,7 +855,7 @@ class ConfigurableTask(Task):
cache_dir=None,
download_mode=None,
config: Optional[dict] = None,
) -> None: # TODO no super() call here
) -> None:
# Get pre-configured attributes
self._config = self.CONFIG
......@@ -784,83 +900,7 @@ class ConfigurableTask(Task):
if self.config.dataset_name is not None:
self.DATASET_NAME = self.config.dataset_name
self._metric_fn_list = {}
self._metric_fn_kwargs = {}
self._aggregation_list = {}
self._higher_is_better = {}
if self.config.metric_list is None:
# TODO: handle this in TaskConfig.__post_init__ ?
_metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
for metric_name in _metric_list:
self._metric_fn_list[metric_name] = get_metric(metric_name)
self._metric_fn_kwargs[metric_name] = {}
self._aggregation_list[metric_name] = get_metric_aggregation(
metric_name
)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
else:
for metric_config in self.config.metric_list:
if "metric" not in metric_config:
raise ValueError(
"'metric' key not provided for an entry in 'metric_list', must be specified!"
)
metric_name = metric_config["metric"]
kwargs = {
key: metric_config[key]
for key in metric_config
if key
not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
}
hf_evaluate_metric = (
"hf_evaluate" in metric_config
and metric_config["hf_evaluate"] is True
)
if self.config.process_results is not None:
self._metric_fn_list[metric_name] = None
self._metric_fn_kwargs[metric_name] = {}
elif callable(metric_name):
metric_fn = metric_name.__call__
metric_name = metric_name.__name__
self._metric_fn_list[metric_name] = metric_fn
self._metric_fn_kwargs[metric_name] = kwargs
else:
self._metric_fn_list[metric_name] = get_metric(
metric_name, hf_evaluate_metric
)
self._metric_fn_kwargs[metric_name] = kwargs
if "aggregation" in metric_config:
agg_name = metric_config["aggregation"]
if isinstance(agg_name, str):
self._aggregation_list[metric_name] = get_aggregation(agg_name)
elif callable(agg_name): # noqa: E721
self._aggregation_list[metric_name] = metric_config[
"aggregation"
]
else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
metric_agg = get_metric_aggregation(metric_name)
eval_logger.warning(
f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. "
f"using default "
f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
)
self._aggregation_list[metric_name] = metric_agg
if "higher_is_better" in metric_config:
self._higher_is_better[metric_name] = metric_config[
"higher_is_better"
]
else:
eval_logger.warning(
f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. "
f"using default "
f"higher_is_better={is_higher_better(metric_name)}"
)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
self.metric_list: list[MetricConfig] = self._config.get_metrics()
self.download(self.config.dataset_kwargs)
self._training_docs = None
......@@ -868,17 +908,23 @@ class ConfigurableTask(Task):
if self.config.filter_list is not None:
self._filters = []
for filter_config in self.config.filter_list:
filter_name = filter_config["name"]
filter_functions = filter_config["filter"]
components = []
for function in filter_functions:
kwargs = {
key: function[key] for key in function if key != "function"
}
components.append([function["function"], kwargs])
filter_pipeline = build_filter_ensemble(filter_name, components)
self._filters.append(filter_pipeline)
if isinstance(self.config.filter_list, dict):
for filter_config in self.config.filter_list:
self._filters.append(
build_filter_ensemble(
filter_config["name"],
[
[
{
key: function[key]
for key in function
if key != "function"
}
]
for function in filter_config["filter"]
],
)
)
else:
# TODO: handle repeats in a more general way rather than just discarding
eval_logger.debug(
......@@ -1476,7 +1522,7 @@ class ConfigurableTask(Task):
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in self._metric_fn_list.keys():
if "acc_mutual_info" in [m.metric_names for m in self.metric_list]:
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
......@@ -1543,7 +1589,7 @@ class ConfigurableTask(Task):
return self.config.process_results(doc, results)
result_dict = {}
use_metric = list(self._metric_fn_list.keys())
use_metric = list(m.metric_names for m in self.metric_list)
if self.OUTPUT_TYPE == "loglikelihood":
results = results[0]
ll, is_greedy = results
......@@ -1579,10 +1625,7 @@ class ConfigurableTask(Task):
choices = self.doc_to_choice(doc)
completion_len = np.array([float(len(i)) for i in choices])
if (
2 * len(choices) == len(lls)
and "acc_mutual_info" in self._metric_fn_list.keys()
):
if 2 * len(choices) == len(lls) and "acc_mutual_info" in use_metric:
# then we are doing mutual info.
# this stores the "dryrun" / unconditional answer loglikelihoods
# as we extend the args list with unconditional ("", continuation) pairs
......@@ -1667,12 +1710,12 @@ class ConfigurableTask(Task):
gold = list(gold)
# TODO: handle this better
elif type(gold) is not type(result) and not (
"bypass" in self._metric_fn_list.keys() or isinstance(result, list)
"bypass" in use_metric or isinstance(result, list)
):
# cast gold to the same type as result
gold = type(result)(gold)
for metric in self._metric_fn_list.keys():
for metric in self.metric_list:
if self.multiple_target:
# in the case where we have multiple targets,
# return true if any are true
......@@ -1682,28 +1725,26 @@ class ConfigurableTask(Task):
# sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
# print(gold)
gold = [gold]
if metric == "exact_match":
if metric.name == "exact_match":
result = [result for _ in range(len(gold))]
scores = self._metric_fn_list[metric](
scores = metric.fn(
references=gold,
predictions=result,
**self._metric_fn_kwargs[metric],
**metric.kwargs,
)[metric]
result_score = 1.0 if scores > 0.0 else 0.0
else:
for gold_option in gold:
try:
result_score = self._metric_fn_list[metric](
result_score = metric.fn(
references=[gold_option],
predictions=[result],
**self._metric_fn_kwargs[metric],
**metric.kwargs,
)
except (
TypeError
): # TODO: this is hacky and I don't want to do it
result_score = self._metric_fn_list[metric](
[gold_option, result]
)
result_score = metric.fn([gold_option, result])
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
result_score = result_score[metric]
......@@ -1714,13 +1755,13 @@ class ConfigurableTask(Task):
result_score = 0.0
else:
try:
result_score = self._metric_fn_list[metric](
result_score = metric.fn(
references=[gold],
predictions=[result],
**self._metric_fn_kwargs[metric],
**metric.kwargs,
)
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
result_score = self._metric_fn_list[metric]([gold, result])
result_score = metric.fn([gold, result])
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
# This allows for multiple metrics to be returned from the same function
......@@ -1737,10 +1778,10 @@ class ConfigurableTask(Task):
return result_dict
def aggregation(self) -> dict:
return self._aggregation_list
return {k.name: k.aggregation_fn for k in self.metric_list}
def higher_is_better(self) -> dict:
return self._higher_is_better
return {k.name: k.higher_is_better for k in self.metric_list}
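A sketch of the shape the rewritten aggregation() and higher_is_better() helpers now return, again assuming MetricConfig is importable after this commit and using statistics.mean as a stand-in aggregation:

from statistics import mean

from lm_eval.api.task import MetricConfig  # assumes this commit is applied

metric_list = [
    MetricConfig(name="acc", aggregation_fn=mean, higher_is_better=True),
    MetricConfig(name="perplexity", aggregation_fn=mean, higher_is_better=False),
]
# Mirrors the helpers above: plain dicts keyed by metric name.
aggregation = {m.name: m.aggregation_fn for m in metric_list}
higher_is_better = {m.name: m.higher_is_better for m in metric_list}
assert higher_is_better == {"acc": True, "perplexity": False}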
def get_config(self, key: str) -> Any:
return getattr(self._config, key, None)
......
......@@ -272,7 +272,7 @@ def simple_evaluate(
# helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
# (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
def _adjust_config(task_dict):
def _adjust_config(task_dict: dict[str, "Task"]) -> dict[str, "Task"]:
adjusted_task_dict = {}
for task_name, task_obj in task_dict.items():
if isinstance(task_obj, dict):
......
......@@ -12,7 +12,7 @@ from lm_eval.api.metrics import (
pooled_sample_stderr,
stderr_for_metric,
)
from lm_eval.api.task import Task
from lm_eval.api.task import ConfigurableTask, Task
from lm_eval.utils import positional_deprecated
......@@ -58,7 +58,7 @@ class TaskOutput:
group_alias=None,
is_group=None,
):
self.task = task
self.task: Union[Task, ConfigurableTask] = task
self.task_config = task_config
self.task_name = task_name
self.group_name = group_name
......
from functools import partial
from typing import List
from typing import List, Union
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.registry import get_filter
......@@ -8,7 +8,7 @@ from . import custom, extraction, selection, transformation
def build_filter_ensemble(
filter_name: str, components: List[List[str]]
filter_name: str, components: list[Union[list[dict], list[str]]]
) -> FilterEnsemble:
"""
Create a filtering pipeline.
......
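The widened components annotation for build_filter_ensemble is easier to read with an example call. This sketch follows the pre-existing [function name, kwargs] convention that task.py uses when assembling components from a task's filter_list; the filter names and kwargs here are illustrative:

from lm_eval.filters import build_filter_ensemble  # assumes lm_eval is installed

# Each component pairs a registered filter/function name with its kwargs.
components = [
    ["regex", {"regex_pattern": r"(\d+)"}],  # illustrative extraction step
    ["take_first", {}],                      # keep the first filtered response
]
ensemble = build_filter_ensemble("my_pipeline", components)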