"test/verify/test_reduce_noop_add.cpp" did not exist on "8d21fdc9dd58e62192d9408132585eea94bbf79b"
Commit 28c78d30 authored by Baber

add MetricConfig

parent de496b80
......@@ -485,6 +485,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if results is not None:
if args.log_samples:
samples = results.pop("samples")
# TODO: fix this!
results["higher_is_better"] = {
k: True for k, v in results["higher_is_better"].items()
}
dumped = json.dumps(
results, indent=2, default=handle_non_serializable, ensure_ascii=False
)
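
For context on the `json.dumps(..., default=handle_non_serializable, ...)` call above, here is a minimal sketch of what such a fallback encoder typically looks like (illustrative only; the actual helper is defined elsewhere in the harness and may handle more types):

```python
import json

import numpy as np


def handle_non_serializable(obj):
    # Fallback used by json.dumps for objects it cannot encode natively
    # (sketch only; the real lm-eval helper may differ in detail).
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, set):
        return list(obj)
    return str(obj)


results = {"acc": np.float64(0.5), "seen_doc_ids": {1, 2, 3}}
print(json.dumps(results, indent=2, default=handle_non_serializable, ensure_ascii=False))
```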
......
import abc
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import Any, Callable, List, Optional, Union
......@@ -84,7 +83,7 @@ class GroupConfig(dict):
return str(value)
class ConfigurableGroup(abc.ABC):
class ConfigurableGroup:
def __init__(
self,
config: Optional[dict] = None,
......
......@@ -14,10 +14,23 @@ class Instance:
arguments: tuple
idx: int
metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
default_factory=lambda: (None, None, None)
default_factory=lambda: (None, None, None),
metadata=dict(
description="Metadata tuple containing task name, document ID, and number of repeats."
),
)
resps: list = field(
default_factory=list,
metadata=dict(
description="List of responses from the model for this instance."
),
)
filtered_resps: dict = field(
default_factory=dict,
metadata=dict(
description="List of filtered responses for this instance, keyed by filter name."
),
)
resps: list = field(default_factory=list)
filtered_resps: dict = field(default_factory=dict)
# initialized after init
task_name: Optional[str] = None
......@@ -29,7 +42,7 @@ class Instance:
self.task_name, self.doc_id, self.repeats = self.metadata
@property
def args(self):
def args(self) -> tuple:
"""
Returns (string,) where `string` is the string to calculate loglikelihood over
"""
......
......@@ -8,7 +8,6 @@ from collections.abc import Iterable
from typing import Callable, List, Optional, Sequence, TypeVar
import numpy as np
import sacrebleu
from lm_eval.api.registry import register_aggregation, register_metric
......@@ -92,6 +91,8 @@ def bleu(items):
Higher is better
"""
import sacrebleu
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
......@@ -107,6 +108,8 @@ def chrf(items):
Higher is better # TODO I think
"""
import sacrebleu
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
......@@ -123,6 +126,8 @@ def ter(items):
Lower is better
"""
import sacrebleu
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
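
The three hunks above move `import sacrebleu` from module level into `bleu`, `chrf`, and `ter`, so the optional dependency is only needed when one of those metrics is actually computed. A condensed illustration of the lazy-import pattern (the function body is a simplified sketch, not the harness implementation):

```python
def corpus_bleu_score(items):
    # Lazy import: sacrebleu is only required if this metric is used.
    import sacrebleu

    refs, preds = zip(*items)
    return sacrebleu.corpus_bleu(list(preds), [list(refs)]).score


print(corpus_bleu_score([("the cat sat", "the cat sat"), ("hello world", "hello there")]))
```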
......
import logging
from typing import Callable, Dict, Union
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
import evaluate as hf_evaluate
from lm_eval.api.model import LM
if TYPE_CHECKING:
from lm_eval.api.model import LM
eval_logger = logging.getLogger(__name__)
......@@ -12,6 +11,8 @@ MODEL_REGISTRY = {}
def register_model(*names):
from lm_eval.api.model import LM
# either pass a list or a single alias.
# function receives them as a tuple of strings
......@@ -31,7 +32,7 @@ def register_model(*names):
return decorate
def get_model(model_name):
def get_model(model_name: str) -> type["LM"]:
try:
return MODEL_REGISTRY[model_name]
except KeyError:
......@@ -46,7 +47,7 @@ ALL_TASKS = set()
func2task_index = {}
def register_task(name):
def register_task(name: str):
def decorate(fn):
assert name not in TASK_REGISTRY, (
f"task named '{name}' conflicts with existing registered task!"
......@@ -120,7 +121,7 @@ def register_metric(**args):
return decorate
def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
def get_metric(name: str, hf_evaluate_metric=False) -> Optional[Callable]:
if not hf_evaluate_metric:
if name in METRIC_REGISTRY:
return METRIC_REGISTRY[name]
......@@ -130,6 +131,8 @@ def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
)
try:
import evaluate as hf_evaluate
metric_object = hf_evaluate.load(name)
return metric_object.compute
except Exception:
......@@ -150,21 +153,21 @@ def register_aggregation(name: str):
return decorate
def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
def get_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(f"{name} not a registered aggregation metric!")
def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
try:
return METRIC_AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
def is_higher_better(metric_name) -> bool:
def is_higher_better(metric_name) -> Optional[bool]:
try:
return HIGHER_IS_BETTER_REGISTRY[metric_name]
except KeyError:
......
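
The registry changes rely on the standard `TYPE_CHECKING` idiom: `LM` is imported only for static type checkers, which avoids a runtime circular import while still allowing the `type["LM"]` return annotation. A condensed sketch of the idiom:

```python
from typing import TYPE_CHECKING, Dict

if TYPE_CHECKING:
    # Seen only by type checkers; never imported at runtime, so no cycle.
    from lm_eval.api.model import LM

MODEL_REGISTRY: Dict[str, type] = {}


def get_model(model_name: str) -> type["LM"]:
    # The quoted annotation is resolved lazily, so LM need not be
    # importable here at runtime.
    return MODEL_REGISTRY[model_name]
```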
import logging
import warnings
from functools import partial
from typing import TYPE_CHECKING, Iterable, Optional, Union
from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union
import datasets
......@@ -181,7 +181,7 @@ class ContextSampler:
return chat_history
def sample(self, n: int):
def sample(self, n: int) -> Sequence[dict]:
"""
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
"""
......@@ -190,7 +190,7 @@ class ContextSampler:
class FirstNSampler(ContextSampler):
def sample(self, n: int) -> None:
def sample(self, n: int) -> Sequence[dict]:
"""
Draw the first `n` samples in order from the specified split.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
......
......@@ -6,6 +6,7 @@ import re
from collections.abc import Callable
from copy import deepcopy
from dataclasses import asdict, dataclass
from functools import cached_property
from inspect import getsource
from typing import (
Any,
......@@ -23,6 +24,7 @@ from typing import (
import datasets
import numpy as np
from tqdm import tqdm
from typing_extensions import deprecated
from lm_eval import utils
from lm_eval.api import samplers
......@@ -51,6 +53,43 @@ ALL_OUTPUT_TYPES = [
eval_logger = logging.getLogger(__name__)
@dataclass
class MetricConfig:
"""Encapsulates information about a single metric."""
name: str
fn: Optional[Callable] = None
kwargs: Optional[dict] = None
aggregation_fn: Optional[Callable] = None
higher_is_better: bool = True
hf_evaluate: bool = False
@cached_property
def metric_names(self) -> str:
return self.name
@cached_property
def aggregation(self) -> Callable:
if self.aggregation_fn is None:
return get_aggregation(self.name)
return self.aggregation_fn
@cached_property
def _higher_is_better(self) -> bool:
if self.higher_is_better is None:
return is_higher_better(self.name)
return self.higher_is_better
@dataclass
class FilterConfig:
"""Encapsulates information about a single filter."""
name: str
fn: Optional[Callable] = None
kwargs: Optional[dict] = None
@dataclass
class TaskConfig(dict):
# task naming/registry
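
A brief sketch of how the new `MetricConfig` might be instantiated and queried (the import path is assumed from this diff, and `exact_match` below is a stand-in function, not the registered harness metric):

```python
from lm_eval.api.task import MetricConfig  # import path assumed from this diff


def exact_match(references, predictions, **kwargs):
    # Stand-in metric function for illustration only.
    return {"exact_match": float(references == predictions)}


cfg = MetricConfig(name="exact_match", fn=exact_match, kwargs={})
print(cfg.metric_names)      # "exact_match" -- cached_property returning .name
print(cfg.higher_is_better)  # True by default
print(cfg.fn(["42"], ["42"], **cfg.kwargs))
```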
......@@ -99,6 +138,8 @@ class TaskConfig(dict):
metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
)
_metric_list = None
_filter_list = None
def __post_init__(self) -> None:
if self.generation_kwargs is not None:
......@@ -133,6 +174,93 @@ class TaskConfig(dict):
f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
)
if self.metric_list is not None:
for metric_config in self.metric_list:
if "metric" not in metric_config:
raise ValueError(
"'metric' key not provided for an entry in 'metric_list', must be specified!"
)
def get_metrics(self) -> list["MetricConfig"]:
metrics = []
if self.metric_list is None:
_metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
metrics.extend(
MetricConfig(
name=metric_name,
fn=get_metric(metric_name),
aggregation_fn=get_metric_aggregation(metric_name),
higher_is_better=is_higher_better(metric_name),
)
for metric_name in _metric_list
)
else:
for metric_config in self.metric_list:
if "metric" not in metric_config:
raise ValueError(
"'metric' key not provided for an entry in 'metric_list', must be specified!"
)
metric_name = metric_config["metric"]
_metric_fn_kwargs = {
key: metric_config[key]
for key in metric_config
if key
not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
}
_hf_evaluate_metric: bool = metric_config.get("hf_evaluate", False)
_metric_fn = None
_aggregation = None
if self.process_results is not None:
# User will compute metrics inside `process_results()`
_metric_name = None
_metric_fn_kwargs = {}
elif callable(metric_name):
# User passed a function object
_metric_name = metric_name.__name__
_metric_fn = metric_name.__call__
else:
# Normal: look up by name
_metric_name = get_metric(metric_name, _hf_evaluate_metric)
# ---------- 3. Decide how to aggregate examples ----------
if "aggregation" in metric_config:
if isinstance(_agg_name := metric_config["aggregation"], str):
_aggregation = get_aggregation(_agg_name)
elif callable(_agg_name): # noqa: E721
_aggregation = metric_config["aggregation"]
else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
_aggregation = get_metric_aggregation(metric_name)
eval_logger.warning(
f"[Task: {self.task}] metric {metric_name} is defined, but aggregation is not. "
f"using default "
f"aggregation={INV_AGG_REGISTRY[_aggregation]}"
)
# ---------- 4. Determine “higher-is-better” semantics ----------
if "higher_is_better" in metric_config:
_higher_is_better = metric_config["higher_is_better"]
else:
eval_logger.warning(
f"[Task: {self.task}] metric {metric_name} is defined, but higher_is_better is not. "
f"using default "
f"higher_is_better={is_higher_better(metric_name)}"
)
_higher_is_better = is_higher_better(metric_name)
metrics.append(
MetricConfig(
name=_metric_name,
fn=_metric_fn,
kwargs=_metric_fn_kwargs,
aggregation_fn=_aggregation,
higher_is_better=_higher_is_better,
hf_evaluate=_hf_evaluate_metric,
)
)
return metrics
def __getitem__(self, item):
return getattr(self, item)
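
For reference, `get_metrics()` consumes `metric_list` entries shaped like the dicts below; every key other than `metric`, `aggregation`, `higher_is_better`, and `hf_evaluate` is forwarded to the metric function as a keyword argument (the values here are illustrative, not from a real task config):

```python
metric_list = [
    # "ignore_case" is not a reserved key, so it is passed through to the
    # metric function as a kwarg.
    {
        "metric": "exact_match",
        "aggregation": "mean",
        "higher_is_better": True,
        "ignore_case": True,
    },
    # Omitted keys fall back to the registry defaults; hf_evaluate=True
    # would route the lookup through evaluate.load(name) instead.
    {"metric": "bleu"},
]
```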
......@@ -534,7 +662,7 @@ class Task(abc.ABC):
"""
pass
@abc.abstractmethod
@deprecated("not used anymore")
def aggregation(self):
"""
:returns: {str: [metric_score] -> float}
......@@ -543,7 +671,7 @@ class Task(abc.ABC):
"""
pass
@abc.abstractmethod
@deprecated("not used anymore")
def higher_is_better(self):
"""
:returns: {str: bool}
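
The `@deprecated("not used anymore")` marker replacing `@abc.abstractmethod` comes from `typing_extensions` (imported in the first hunk of this file). A tiny standalone illustration (assuming a reasonably recent `typing_extensions`):

```python
from typing_extensions import deprecated


@deprecated("not used anymore")
def aggregation_stub():
    return {}


# Static type checkers flag call sites of aggregation_stub; recent
# typing_extensions versions also emit a DeprecationWarning at call time.
aggregation_stub()
```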
......@@ -661,23 +789,13 @@ class Task(abc.ABC):
Parameters:
- metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
"""
(
self._metric_fn_list,
self._aggregation_list,
self._metric_fn_kwargs,
self._higher_is_better,
) = ({}, {}, {}, {})
self._metric_fn_list[metric_name] = get_metric(metric_name)
self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
self._metric_fn_kwargs[metric_name] = {}
if not isinstance(self, ConfigurableTask):
self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
self.aggregation = lambda: {
metric_name: get_metric_aggregation(metric_name)
}
setattr(self._config, "metric_list", [{"metric": metric_name}])
setattr(self._config, "process_results", None)
# if not isinstance(self, ConfigurableTask):
# self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
# self.aggregation = lambda: {
# metric_name: get_metric_aggregation(metric_name)
# }
setattr(self._config, "metric_list", [MetricConfig(name=metric_name)])
setattr(self._config, "process_results", lambda *args: {"bypass": 0})
def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
self.fewshot_rnd = random.Random(seed)
......@@ -739,7 +857,7 @@ class ConfigurableTask(Task):
cache_dir=None,
download_mode=None,
config: Optional[dict] = None,
) -> None: # TODO no super() call here
) -> None:
# Get pre-configured attributes
self._config = self.CONFIG
......@@ -784,83 +902,7 @@ class ConfigurableTask(Task):
if self.config.dataset_name is not None:
self.DATASET_NAME = self.config.dataset_name
self._metric_fn_list = {}
self._metric_fn_kwargs = {}
self._aggregation_list = {}
self._higher_is_better = {}
if self.config.metric_list is None:
# TODO: handle this in TaskConfig.__post_init__ ?
_metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
for metric_name in _metric_list:
self._metric_fn_list[metric_name] = get_metric(metric_name)
self._metric_fn_kwargs[metric_name] = {}
self._aggregation_list[metric_name] = get_metric_aggregation(
metric_name
)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
else:
for metric_config in self.config.metric_list:
if "metric" not in metric_config:
raise ValueError(
"'metric' key not provided for an entry in 'metric_list', must be specified!"
)
metric_name = metric_config["metric"]
kwargs = {
key: metric_config[key]
for key in metric_config
if key
not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
}
hf_evaluate_metric = (
"hf_evaluate" in metric_config
and metric_config["hf_evaluate"] is True
)
if self.config.process_results is not None:
self._metric_fn_list[metric_name] = None
self._metric_fn_kwargs[metric_name] = {}
elif callable(metric_name):
metric_fn = metric_name.__call__
metric_name = metric_name.__name__
self._metric_fn_list[metric_name] = metric_fn
self._metric_fn_kwargs[metric_name] = kwargs
else:
self._metric_fn_list[metric_name] = get_metric(
metric_name, hf_evaluate_metric
)
self._metric_fn_kwargs[metric_name] = kwargs
if "aggregation" in metric_config:
agg_name = metric_config["aggregation"]
if isinstance(agg_name, str):
self._aggregation_list[metric_name] = get_aggregation(agg_name)
elif callable(agg_name): # noqa: E721
self._aggregation_list[metric_name] = metric_config[
"aggregation"
]
else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
metric_agg = get_metric_aggregation(metric_name)
eval_logger.warning(
f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. "
f"using default "
f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
)
self._aggregation_list[metric_name] = metric_agg
if "higher_is_better" in metric_config:
self._higher_is_better[metric_name] = metric_config[
"higher_is_better"
]
else:
eval_logger.warning(
f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. "
f"using default "
f"higher_is_better={is_higher_better(metric_name)}"
)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
self.metric_list: list[MetricConfig] = self._config.get_metrics()
self.download(self.config.dataset_kwargs)
self._training_docs = None
......@@ -868,17 +910,23 @@ class ConfigurableTask(Task):
if self.config.filter_list is not None:
self._filters = []
for filter_config in self.config.filter_list:
filter_name = filter_config["name"]
filter_functions = filter_config["filter"]
components = []
for function in filter_functions:
kwargs = {
key: function[key] for key in function if key != "function"
}
components.append([function["function"], kwargs])
filter_pipeline = build_filter_ensemble(filter_name, components)
self._filters.append(filter_pipeline)
if isinstance(self.config.filter_list, dict):
for filter_config in self.config.filter_list:
self._filters.append(
build_filter_ensemble(
filter_config["name"],
[
[
{
key: function[key]
for key in function
if key != "function"
}
]
for function in filter_config["filter"]
],
)
)
else:
# TODO: handle repeats in a more general way rather than just discarding
eval_logger.debug(
......@@ -1297,7 +1345,7 @@ class ConfigurableTask(Task):
return doc[doc_to_text]
else:
text_string = utils.apply_template(doc_to_text, doc)
if text_string.isdigit() and self._config.doc_to_choice is not None:
if text_string.isdigit() and self.config.doc_to_choice is not None:
return ast.literal_eval(text_string)
else:
return text_string
......@@ -1333,7 +1381,7 @@ class ConfigurableTask(Task):
return doc[doc_to_target]
else:
target_string = utils.apply_template(doc_to_target, doc)
if target_string.isdigit() and self._config.doc_to_choice is not None:
if target_string.isdigit() and self.config.doc_to_choice is not None:
return ast.literal_eval(target_string)
elif (
len(target_string) >= 2
......@@ -1480,7 +1528,7 @@ class ConfigurableTask(Task):
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in self._metric_fn_list.keys():
if "acc_mutual_info" in [m.metric_names for m in self.metric_list]:
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
......@@ -1547,7 +1595,7 @@ class ConfigurableTask(Task):
return self.config.process_results(doc, results)
result_dict = {}
use_metric = list(self._metric_fn_list.keys())
use_metric = list(m.metric_names for m in self.metric_list)
if self.OUTPUT_TYPE == "loglikelihood":
results = results[0]
ll, is_greedy = results
......@@ -1583,10 +1631,7 @@ class ConfigurableTask(Task):
choices = self.doc_to_choice(doc)
completion_len = np.array([float(len(i)) for i in choices])
if (
2 * len(choices) == len(lls)
and "acc_mutual_info" in self._metric_fn_list.keys()
):
if 2 * len(choices) == len(lls) and "acc_mutual_info" in use_metric:
# then we are doing mutual info.
# this stores the "dryrun" / unconditional answer loglikelihoods
# as we extend the args list with unconditional ("", continuation) pairs
......@@ -1671,12 +1716,12 @@ class ConfigurableTask(Task):
gold = list(gold)
# TODO: handle this better
elif type(gold) is not type(result) and not (
"bypass" in self._metric_fn_list.keys() or isinstance(result, list)
"bypass" in use_metric or isinstance(result, list)
):
# cast gold to the same type as result
gold = type(result)(gold)
for metric in self._metric_fn_list.keys():
for metric in self.metric_list:
if self.multiple_target:
# in the case where we have multiple targets,
# return true if any are true
......@@ -1686,28 +1731,26 @@ class ConfigurableTask(Task):
# sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
# print(gold)
gold = [gold]
if metric == "exact_match":
if metric.name == "exact_match":
result = [result for _ in range(len(gold))]
scores = self._metric_fn_list[metric](
scores = metric.fn(
references=gold,
predictions=result,
**self._metric_fn_kwargs[metric],
**metric.kwargs,
)[metric]
result_score = 1.0 if scores > 0.0 else 0.0
else:
for gold_option in gold:
try:
result_score = self._metric_fn_list[metric](
result_score = metric.fn(
references=[gold_option],
predictions=[result],
**self._metric_fn_kwargs[metric],
**metric.kwargs,
)
except (
TypeError
): # TODO: this is hacky and I don't want to do it
result_score = self._metric_fn_list[metric](
[gold_option, result]
)
result_score = metric.fn([gold_option, result])
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
result_score = result_score[metric]
......@@ -1718,13 +1761,13 @@ class ConfigurableTask(Task):
result_score = 0.0
else:
try:
result_score = self._metric_fn_list[metric](
result_score = metric.fn(
references=[gold],
predictions=[result],
**self._metric_fn_kwargs[metric],
**metric.kwargs,
)
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
result_score = self._metric_fn_list[metric]([gold, result])
result_score = metric.fn([gold, result])
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
# This allows for multiple metrics to be returned from the same function
......@@ -1741,10 +1784,10 @@ class ConfigurableTask(Task):
return result_dict
def aggregation(self) -> dict:
return self._aggregation_list
return {k.name: k.aggregation_fn for k in self.metric_list}
def higher_is_better(self) -> dict:
return self._higher_is_better
return {k.name: k.higher_is_better for k in self.metric_list}
def get_config(self, key: str) -> Any:
return getattr(self._config, key, None)
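
The rewritten `aggregation()` and `higher_is_better()` now just project per-metric fields out of `self.metric_list` instead of maintaining parallel dicts. A standalone sketch of the same projection (import path assumed from this diff):

```python
from lm_eval.api.task import MetricConfig  # import path assumed from this diff

metric_list = [
    MetricConfig(name="acc"),
    MetricConfig(name="f1", higher_is_better=True),
]

aggregation = {m.name: m.aggregation_fn for m in metric_list}
higher_is_better = {m.name: m.higher_is_better for m in metric_list}
print(higher_is_better)  # {'acc': True, 'f1': True}
```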
......
......@@ -287,7 +287,7 @@ def simple_evaluate(
# helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
# (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
def _adjust_config(task_dict):
def _adjust_config(task_dict: dict[str, "Task"]) -> dict[str, "Task"]:
adjusted_task_dict = {}
for task_name, task_obj in task_dict.items():
if isinstance(task_obj, dict):
......
......@@ -12,7 +12,7 @@ from lm_eval.api.metrics import (
pooled_sample_stderr,
stderr_for_metric,
)
from lm_eval.api.task import Task
from lm_eval.api.task import ConfigurableTask, Task
from lm_eval.utils import positional_deprecated
......@@ -58,7 +58,7 @@ class TaskOutput:
group_alias=None,
is_group=None,
):
self.task = task
self.task: Union[Task, ConfigurableTask] = task
self.task_config = task_config
self.task_name = task_name
self.group_name = group_name
......
from functools import partial
from typing import List
from typing import List, Union
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.registry import get_filter
......@@ -8,7 +8,7 @@ from . import custom, extraction, selection, transformation
def build_filter_ensemble(
filter_name: str, components: List[List[str]]
filter_name: str, components: list[Union[list[dict], list[str]]]
) -> FilterEnsemble:
"""
Create a filtering pipeline.
......
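
The widened `components` annotation above matches calls in which each component is a `[function_name, kwargs]` pair. A hypothetical call shape (the filter name and regex are illustrative; registered filter names depend on `lm_eval.api.registry`):

```python
from lm_eval.filters import build_filter_ensemble  # import path assumed

ensemble = build_filter_ensemble(
    "strict-match",
    [
        # [registered filter name, kwargs for that filter]
        ["regex", {"regex_pattern": r"#### (\-?[0-9\.\,]+)"}],
    ],
)
```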