Commit 28c78d30 authored by Baber

add MetricConfig

parent de496b80
@@ -485,6 +485,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
     if results is not None:
         if args.log_samples:
             samples = results.pop("samples")
+        # TODO: fix this!
+        results["higher_is_better"] = {
+            k: True for k, v in results["higher_is_better"].items()
+        }
         dumped = json.dumps(
             results, indent=2, default=handle_non_serializable, ensure_ascii=False
         )
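The dump above leans on json.dumps's `default` hook to survive values the encoder cannot handle natively. The body of handle_non_serializable is not part of this diff, so the following is only a sketch of what such a hook typically does; the type cases are hypothetical:

import json

import numpy as np


def handle_non_serializable(obj):
    # Called by json.dumps for any object it cannot encode natively.
    # These cases are illustrative, not the function's actual body.
    if isinstance(obj, (np.integer, np.floating)):
        return obj.item()  # unwrap numpy scalars into plain Python numbers
    if isinstance(obj, set):
        return sorted(obj)  # sets have no JSON form; emit a stable list
    return str(obj)  # last resort: stringify


print(json.dumps({"n": np.int64(3), "tags": {"b", "a"}},
                 default=handle_non_serializable))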
......
-import abc
 from dataclasses import asdict, dataclass
 from inspect import getsource
 from typing import Any, Callable, List, Optional, Union
@@ -84,7 +83,7 @@ class GroupConfig(dict):
         return str(value)


-class ConfigurableGroup(abc.ABC):
+class ConfigurableGroup:
    def __init__(
        self,
        config: Optional[dict] = None,
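Dropping the abc.ABC base (and with it, presumably, the abc import) is behavior-preserving: ABCMeta only blocks instantiation when a class actually declares abstract methods, and none are visible on ConfigurableGroup. A quick self-contained illustration (class names hypothetical):

import abc


class NoAbstracts(abc.ABC):  # ABC base, but no @abstractmethod anywhere
    pass


class HasAbstract(abc.ABC):
    @abc.abstractmethod
    def run(self): ...


NoAbstracts()  # instantiates fine; the ABC base added nothing
try:
    HasAbstract()
except TypeError as err:
    print(err)  # can't instantiate abstract class HasAbstract ...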
......
@@ -14,10 +14,23 @@ class Instance:
     arguments: tuple
     idx: int
     metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
-        default_factory=lambda: (None, None, None)
+        default_factory=lambda: (None, None, None),
+        metadata=dict(
+            description="Metadata tuple containing task name, document ID, and number of repeats."
+        ),
     )
-    resps: list = field(default_factory=list)
-    filtered_resps: dict = field(default_factory=dict)
+    resps: list = field(
+        default_factory=list,
+        metadata=dict(
+            description="List of responses from the model for this instance."
+        ),
+    )
+    filtered_resps: dict = field(
+        default_factory=dict,
+        metadata=dict(
+            description="List of filtered responses for this instance, keyed by filter name."
+        ),
+    )

     # initialized after init
     task_name: Optional[str] = None
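The new `metadata=dict(description=...)` arguments change nothing at runtime; dataclasses simply stores the mapping on each field, where it can be read back via dataclasses.fields(). A minimal sketch of recovering the descriptions, using a reduced stand-in for Instance:

from dataclasses import dataclass, field, fields


@dataclass
class Instance:  # reduced stand-in for the class in this diff
    resps: list = field(
        default_factory=list,
        metadata=dict(description="List of responses from the model for this instance."),
    )


for f in fields(Instance):
    # f.metadata is an immutable mapping; .get() avoids KeyError on bare fields
    print(f"{f.name}: {f.metadata.get('description', '<no description>')}")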
@@ -29,7 +42,7 @@ class Instance:
         self.task_name, self.doc_id, self.repeats = self.metadata

     @property
-    def args(self):
+    def args(self) -> tuple:
         """
         Returns (string,) where `string` is the string to calculate loglikelihood over
         """
......
@@ -8,7 +8,6 @@
 from collections.abc import Iterable
 from typing import Callable, List, Optional, Sequence, TypeVar

 import numpy as np
-import sacrebleu

 from lm_eval.api.registry import register_aggregation, register_metric
@@ -92,6 +91,8 @@ def bleu(items):

     Higher is better
     """
+    import sacrebleu
+
     refs = list(zip(*items))[0]
     preds = list(zip(*items))[1]
     refs, preds = _sacreformat(refs, preds)
@@ -107,6 +108,8 @@ def chrf(items):

     Higher is better # TODO I think
     """
+    import sacrebleu
+
     refs = list(zip(*items))[0]
     preds = list(zip(*items))[1]
     refs, preds = _sacreformat(refs, preds)
@@ -123,6 +126,8 @@ def ter(items):

     Lower is better
     """
+    import sacrebleu
+
     refs = list(zip(*items))[0]
     preds = list(zip(*items))[1]
     refs, preds = _sacreformat(refs, preds)
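All three metrics now pull sacrebleu in at call time rather than at module import. This deferred-import pattern keeps `import lm_eval.api.metrics` fast and makes sacrebleu an effectively optional dependency, paid for only when a BLEU/chrF/TER task actually runs. A generic sketch of the pattern; the function name and the friendly error message are my additions, not lm-eval's:

def corpus_bleu_score(preds: list[str], refs: list[list[str]]) -> float:
    try:
        # Deferred import: resolved on first call, then cached in sys.modules,
        # so repeated calls pay no extra cost.
        import sacrebleu
    except ImportError as err:
        raise ImportError(
            "sacrebleu is needed for BLEU/chrF/TER; try `pip install sacrebleu`"
        ) from err
    return sacrebleu.corpus_bleu(preds, refs).score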
......
 import logging
-from typing import Callable, Dict, Union
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Union

-import evaluate as hf_evaluate
-
-from lm_eval.api.model import LM
+if TYPE_CHECKING:
+    from lm_eval.api.model import LM

 eval_logger = logging.getLogger(__name__)
@@ -12,6 +11,8 @@ MODEL_REGISTRY = {}


 def register_model(*names):
+    from lm_eval.api.model import LM
+
     # either pass a list or a single alias.
     # function receives them as a tuple of strings
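The LM import now exists in two forms: under `if TYPE_CHECKING:` (seen by type checkers only, never executed) to support the annotations, and as a local import inside register_model for whatever runtime check the decorator performs. This is the standard way to break a circular import between lm_eval.api.registry and lm_eval.api.model. A condensed sketch of the pattern; the issubclass assertion is my guess at why the runtime import is needed:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from lm_eval.api.model import LM  # type-checker only; no runtime cycle

MODEL_REGISTRY: dict = {}


def register_model(*names):
    from lm_eval.api.model import LM  # runtime import, deferred to decoration time

    def decorate(cls):
        assert issubclass(cls, LM)  # assumed reason the decorator needs LM at runtime
        for name in names:
            MODEL_REGISTRY[name] = cls
        return cls

    return decorate


def get_model(model_name: str) -> type["LM"]:  # string annotation resolves lazily
    return MODEL_REGISTRY[model_name]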
@@ -31,7 +32,7 @@ def register_model(*names):
     return decorate


-def get_model(model_name):
+def get_model(model_name: str) -> type["LM"]:
     try:
         return MODEL_REGISTRY[model_name]
     except KeyError:
@@ -46,7 +47,7 @@ ALL_TASKS = set()
 func2task_index = {}


-def register_task(name):
+def register_task(name: str):
     def decorate(fn):
         assert name not in TASK_REGISTRY, (
             f"task named '{name}' conflicts with existing registered task!"
@@ -120,7 +121,7 @@ def register_metric(**args):
     return decorate


-def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
+def get_metric(name: str, hf_evaluate_metric=False) -> Optional[Callable]:
     if not hf_evaluate_metric:
         if name in METRIC_REGISTRY:
             return METRIC_REGISTRY[name]
@@ -130,6 +131,8 @@ def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
             )

     try:
+        import evaluate as hf_evaluate
+
         metric_object = hf_evaluate.load(name)
         return metric_object.compute
     except Exception:
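Moving `import evaluate as hf_evaluate` into this try block means the Hugging Face evaluate package is only loaded on a registry miss. The load/compute pair used here is the library's standard API:

import evaluate as hf_evaluate

# hf_evaluate.load() fetches a metric by name; .compute() runs it.
exact_match = hf_evaluate.load("exact_match")
print(exact_match.compute(predictions=["paris"], references=["paris"]))
# -> {'exact_match': 1.0}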
@@ -150,21 +153,21 @@ def register_aggregation(name: str):
     return decorate


-def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
+def get_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
     try:
         return AGGREGATION_REGISTRY[name]
     except KeyError:
         eval_logger.warning(f"{name} not a registered aggregation metric!")


-def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
+def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
     try:
         return METRIC_AGGREGATION_REGISTRY[name]
     except KeyError:
         eval_logger.warning(f"{name} metric is not assigned a default aggregation!")


-def is_higher_better(metric_name) -> bool:
+def is_higher_better(metric_name) -> Optional[bool]:
     try:
         return HIGHER_IS_BETTER_REGISTRY[metric_name]
     except KeyError:
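The Optional[...] annotations now match what these getters actually do: on an unknown name they log a warning and fall off the end, returning None rather than raising. Callers therefore have to guard. A sketch of the caller-side contract (assuming a "mean" aggregation is registered, as it is in lm-eval):

from lm_eval.api.registry import get_aggregation

agg = get_aggregation("mean")
if agg is None:
    # unknown name: get_aggregation has already logged a warning
    raise ValueError("no aggregation registered under 'mean'")
print(agg([0.0, 1.0, 1.0]))  # 0.666...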
......
 import logging
 import warnings
 from functools import partial
-from typing import TYPE_CHECKING, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union

 import datasets
@@ -181,7 +181,7 @@ class ContextSampler:
         return chat_history

-    def sample(self, n: int):
+    def sample(self, n: int) -> Sequence[dict]:
         """
         Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
         """
@@ -190,7 +190,7 @@ class ContextSampler:


 class FirstNSampler(ContextSampler):
-    def sample(self, n: int) -> None:
+    def sample(self, n: int) -> Sequence[dict]:
         """
         Draw the first `n` samples in order from the specified split.
         Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
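The old `-> None` annotation was simply wrong: sample returns the drawn docs. Under the corrected Sequence[dict] contract, a FirstNSampler-style implementation reduces to a deterministic slice. A minimal self-contained sketch, assuming the sampler keeps its fewshot docs on self.docs as the lm-eval base class does:

from typing import Sequence


class FirstNSampler:  # reduced stand-in; the real class extends ContextSampler
    def __init__(self, docs: Sequence[dict]):
        self.docs = docs

    def sample(self, n: int) -> Sequence[dict]:
        assert n <= len(self.docs), (
            f"requested {n} fewshot docs but only {len(self.docs)} are available"
        )
        return self.docs[:n]  # deterministic: always the first n, in order


sampler = FirstNSampler([{"q": "1+1?", "a": "2"}, {"q": "2+2?", "a": "4"}])
print(sampler.sample(1))  # [{'q': '1+1?', 'a': '2'}]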
......
@@ -287,7 +287,7 @@ def simple_evaluate(
     # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
     # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
-    def _adjust_config(task_dict):
+    def _adjust_config(task_dict: dict[str, "Task"]) -> dict[str, "Task"]:
         adjusted_task_dict = {}
         for task_name, task_obj in task_dict.items():
             if isinstance(task_obj, dict):
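The new annotation documents the recursion the comment describes: group nodes are plain dicts of subtasks, leaves are Task objects, so recursing on isinstance(task_obj, dict) applies overrides only at the leaves. A reduced sketch of that shape, with the override logic simplified to num_fewshot (set_config per lm-eval's Task API; treat the details as illustrative):

def _adjust_config(task_dict: dict, num_fewshot: int) -> dict:
    adjusted = {}
    for task_name, task_obj in task_dict.items():
        if isinstance(task_obj, dict):
            # group node: recurse into its subtasks, leave the group itself alone
            adjusted[task_name] = _adjust_config(task_obj, num_fewshot)
        else:
            # leaf task: apply the override here
            task_obj.set_config(key="num_fewshot", value=num_fewshot)
            adjusted[task_name] = task_obj
    return adjusted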
......
@@ -12,7 +12,7 @@ from lm_eval.api.metrics import (
     pooled_sample_stderr,
     stderr_for_metric,
 )
-from lm_eval.api.task import Task
+from lm_eval.api.task import ConfigurableTask, Task
 from lm_eval.utils import positional_deprecated
@@ -58,7 +58,7 @@ class TaskOutput:
         group_alias=None,
         is_group=None,
     ):
-        self.task = task
+        self.task: Union[Task, ConfigurableTask] = task
         self.task_config = task_config
         self.task_name = task_name
         self.group_name = group_name
......
 from functools import partial
-from typing import List
+from typing import List, Union

 from lm_eval.api.filter import FilterEnsemble
 from lm_eval.api.registry import get_filter
@@ -8,7 +8,7 @@ from . import custom, extraction, selection, transformation


 def build_filter_ensemble(
-    filter_name: str, components: List[List[str]]
+    filter_name: str, components: list[Union[list[dict], list[str]]]
 ) -> FilterEnsemble:
     """
     Create a filtering pipeline.
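The widened components annotation reflects that each component is a [name, kwargs] pair handed to get_filter, not a pair of strings. A hedged usage sketch under that reading; "regex" and "take_first" are filters lm-eval registers, but the exact kwargs here are illustrative:

from lm_eval.filters import build_filter_ensemble

# Each component: [registered_filter_name, kwargs_dict_or_None].
ensemble = build_filter_ensemble(
    "get-answer",
    [
        ["regex", {"regex_pattern": r"(-?[0-9.,]+)"}],  # pull a number out of each response
        ["take_first", None],  # then keep only the first match
    ],
)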
......