Commit 28c78d30 authored by Baber

add MetricConfig

parent de496b80
@@ -485,6 +485,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
     if results is not None:
         if args.log_samples:
             samples = results.pop("samples")
+        # TODO: fix this!
+        results["higher_is_better"] = {
+            k: True for k, v in results["higher_is_better"].items()
+        }
         dumped = json.dumps(
             results, indent=2, default=handle_non_serializable, ensure_ascii=False
         )
...
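
The `json.dumps` call in the context above relies on a `handle_non_serializable` fallback. A minimal sketch of what such a fallback might look like (hypothetical, for illustration only; the real helper is defined elsewhere in the repository and may differ):

import json

def handle_non_serializable(obj):
    # Hypothetical fallback: json.dumps calls this for objects it cannot encode natively.
    if isinstance(obj, (set, frozenset)):
        return list(obj)
    return str(obj)

dumped = json.dumps({"metrics": {"acc_norm"}}, default=handle_non_serializable, ensure_ascii=False)
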
-import abc
 from dataclasses import asdict, dataclass
 from inspect import getsource
 from typing import Any, Callable, List, Optional, Union
@@ -84,7 +83,7 @@ class GroupConfig(dict):
             return str(value)


-class ConfigurableGroup(abc.ABC):
+class ConfigurableGroup:
     def __init__(
         self,
         config: Optional[dict] = None,
...
@@ -14,10 +14,23 @@ class Instance:
     arguments: tuple
     idx: int
     metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
-        default_factory=lambda: (None, None, None)
+        default_factory=lambda: (None, None, None),
+        metadata=dict(
+            description="Metadata tuple containing task name, document ID, and number of repeats."
+        ),
+    )
+    resps: list = field(
+        default_factory=list,
+        metadata=dict(
+            description="List of responses from the model for this instance."
+        ),
+    )
+    filtered_resps: dict = field(
+        default_factory=dict,
+        metadata=dict(
+            description="List of filtered responses for this instance, keyed by filter name."
+        ),
     )
-    resps: list = field(default_factory=list)
-    filtered_resps: dict = field(default_factory=dict)

     # initialized after init
     task_name: Optional[str] = None
@@ -29,7 +42,7 @@ class Instance:
         self.task_name, self.doc_id, self.repeats = self.metadata

     @property
-    def args(self):
+    def args(self) -> tuple:
         """
         Returns (string,) where `string` is the string to calculate loglikelihood over
         """
...
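
The `field(metadata=...)` descriptions added above are not used by the dataclass machinery itself; they only matter if something reads them back. A minimal sketch of how such descriptions can be recovered with `dataclasses.fields()` (a standalone toy class, not the actual `Instance`):

from dataclasses import dataclass, field, fields

@dataclass
class Example:
    resps: list = field(
        default_factory=list,
        metadata=dict(description="List of responses from the model for this instance."),
    )

# Field metadata is exposed as a read-only mapping on each fields() entry.
for f in fields(Example):
    print(f.name, "->", f.metadata.get("description", "<no description>"))
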
@@ -8,7 +8,6 @@ from collections.abc import Iterable
 from typing import Callable, List, Optional, Sequence, TypeVar

 import numpy as np
-import sacrebleu

 from lm_eval.api.registry import register_aggregation, register_metric
@@ -92,6 +91,8 @@ def bleu(items):
     Higher is better
     """
+    import sacrebleu
+
     refs = list(zip(*items))[0]
     preds = list(zip(*items))[1]
     refs, preds = _sacreformat(refs, preds)
@@ -107,6 +108,8 @@ def chrf(items):
     Higher is better # TODO I think
     """
+    import sacrebleu
+
     refs = list(zip(*items))[0]
     preds = list(zip(*items))[1]
     refs, preds = _sacreformat(refs, preds)
@@ -123,6 +126,8 @@ def ter(items):
     Lower is better
     """
+    import sacrebleu
+
     refs = list(zip(*items))[0]
     preds = list(zip(*items))[1]
     refs, preds = _sacreformat(refs, preds)
...
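
Moving `import sacrebleu` from module scope into `bleu`, `chrf`, and `ter` defers the dependency until one of those metrics is actually computed, so importing the metrics module no longer requires sacrebleu to be installed. A minimal sketch of the deferred-import pattern with a friendlier error message (illustrative only, not the code in this commit):

def bleu_score(preds, refs):
    # Deferred import: the optional dependency is only needed when the metric runs.
    try:
        import sacrebleu
    except ImportError as e:
        raise ImportError(
            "sacrebleu is required for BLEU; install it with `pip install sacrebleu`."
        ) from e
    return sacrebleu.corpus_bleu(preds, refs).score
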
 import logging
-from typing import Callable, Dict, Union
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Union

-import evaluate as hf_evaluate

-from lm_eval.api.model import LM
+if TYPE_CHECKING:
+    from lm_eval.api.model import LM

 eval_logger = logging.getLogger(__name__)
@@ -12,6 +11,8 @@ MODEL_REGISTRY = {}


 def register_model(*names):
+    from lm_eval.api.model import LM
+
     # either pass a list or a single alias.
     # function receives them as a tuple of strings
@@ -31,7 +32,7 @@ def register_model(*names):
     return decorate


-def get_model(model_name):
+def get_model(model_name: str) -> type["LM"]:
     try:
         return MODEL_REGISTRY[model_name]
     except KeyError:
@@ -46,7 +47,7 @@ ALL_TASKS = set()
 func2task_index = {}


-def register_task(name):
+def register_task(name: str):
     def decorate(fn):
         assert name not in TASK_REGISTRY, (
             f"task named '{name}' conflicts with existing registered task!"
@@ -120,7 +121,7 @@ def register_metric(**args):
     return decorate


-def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
+def get_metric(name: str, hf_evaluate_metric=False) -> Optional[Callable]:
     if not hf_evaluate_metric:
         if name in METRIC_REGISTRY:
             return METRIC_REGISTRY[name]
@@ -130,6 +131,8 @@ def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
         )

     try:
+        import evaluate as hf_evaluate
+
         metric_object = hf_evaluate.load(name)
         return metric_object.compute
     except Exception:
@@ -150,21 +153,21 @@ def register_aggregation(name: str):
     return decorate


-def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
+def get_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
     try:
         return AGGREGATION_REGISTRY[name]
     except KeyError:
         eval_logger.warning(f"{name} not a registered aggregation metric!")


-def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
+def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
     try:
         return METRIC_AGGREGATION_REGISTRY[name]
     except KeyError:
         eval_logger.warning(f"{name} metric is not assigned a default aggregation!")


-def is_higher_better(metric_name) -> bool:
+def is_higher_better(metric_name) -> Optional[bool]:
     try:
         return HIGHER_IS_BETTER_REGISTRY[metric_name]
     except KeyError:
...
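
The registry now imports `LM` only under `typing.TYPE_CHECKING` (and locally inside `register_model`), which keeps the annotation available to type checkers without paying the import cost, or risking a circular import, at runtime. A minimal sketch of the pattern in isolation (hypothetical module names, not part of this repository):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers; never executed at runtime.
    from heavy_module import HeavyClass  # hypothetical module

REGISTRY: dict = {}

def get_thing(name: str) -> type["HeavyClass"]:
    # The forward-reference string keeps the annotation valid without the runtime import.
    return REGISTRY[name]
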
 import logging
 import warnings
 from functools import partial
-from typing import TYPE_CHECKING, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union

 import datasets
@@ -181,7 +181,7 @@ class ContextSampler:
         return chat_history

-    def sample(self, n: int):
+    def sample(self, n: int) -> Sequence[dict]:
         """
         Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
         """
@@ -190,7 +190,7 @@ class ContextSampler:

 class FirstNSampler(ContextSampler):
-    def sample(self, n: int) -> None:
+    def sample(self, n: int) -> Sequence[dict]:
         """
         Draw the first `n` samples in order from the specified split.
         Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
         """
...
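
Both `sample` overrides now advertise a `Sequence[dict]` return instead of `None`. As a rough illustration of what a first-N sampler boils down to (assuming the sampler keeps its fewshot documents in a `self.docs` sequence, which is an assumption about the class and not shown in this hunk):

from typing import Sequence

class FirstNSamplerSketch:
    def __init__(self, docs: Sequence[dict]):
        self.docs = docs  # assumed fewshot documents for the task's split

    def sample(self, n: int) -> Sequence[dict]:
        # Deterministically take the first n docs, as used for "canonical" fewshot orderings.
        assert n <= len(self.docs), f"requested {n} fewshot docs, only {len(self.docs)} available"
        return self.docs[:n]
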
This diff is collapsed.
@@ -287,7 +287,7 @@ def simple_evaluate(
     # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
     # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
-    def _adjust_config(task_dict):
+    def _adjust_config(task_dict: dict[str, "Task"]) -> dict[str, "Task"]:
         adjusted_task_dict = {}
         for task_name, task_obj in task_dict.items():
             if isinstance(task_obj, dict):
...
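
`_adjust_config` walks the task dict recursively so overrides only reach leaf tasks, not the group entries that merely contain them. A rough sketch of that recursion shape under the same assumption (groups appear as nested dicts, leaf tasks as objects); this is illustrative, not the function body from the commit:

def _adjust_config_sketch(task_dict: dict, override) -> dict:
    adjusted = {}
    for task_name, task_obj in task_dict.items():
        if isinstance(task_obj, dict):
            # Group entry: recurse into its members instead of overriding the group itself.
            adjusted[task_name] = _adjust_config_sketch(task_obj, override)
        else:
            # Leaf task: apply the override (e.g. num_fewshot or the fewshot seed).
            override(task_obj)
            adjusted[task_name] = task_obj
    return adjusted
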
@@ -12,7 +12,7 @@ from lm_eval.api.metrics import (
     pooled_sample_stderr,
     stderr_for_metric,
 )
-from lm_eval.api.task import Task
+from lm_eval.api.task import ConfigurableTask, Task
 from lm_eval.utils import positional_deprecated
@@ -58,7 +58,7 @@ class TaskOutput:
         group_alias=None,
         is_group=None,
     ):
-        self.task = task
+        self.task: Union[Task, ConfigurableTask] = task
         self.task_config = task_config
         self.task_name = task_name
         self.group_name = group_name
...
 from functools import partial
-from typing import List
+from typing import List, Union

 from lm_eval.api.filter import FilterEnsemble
 from lm_eval.api.registry import get_filter
@@ -8,7 +8,7 @@ from . import custom, extraction, selection, transformation

 def build_filter_ensemble(
-    filter_name: str, components: List[List[str]]
+    filter_name: str, components: list[Union[list[dict], list[str]]]
 ) -> FilterEnsemble:
     """
     Create a filtering pipeline.
...
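
The widened signature lets a component carry keyword arguments (as dicts) in addition to bare filter names. As a rough sketch of how such a components list might be turned into callables with `functools.partial` (hypothetical registry and structure; not the actual body of `build_filter_ensemble`):

from functools import partial

# Hypothetical filter registry standing in for lm_eval.api.registry.get_filter.
FILTERS = {"lowercase": lambda resp, **kw: resp.lower()}

def build_filters_sketch(components):
    built = []
    for entry in components:
        name, kwargs = entry[0], entry[1] if len(entry) > 1 else None
        built.append(partial(FILTERS[name], **(kwargs or {})))
    return built

# Example: a bare name and a name paired with (empty) keyword arguments.
fns = build_filters_sketch([["lowercase"], ["lowercase", {}]])
print(fns[0]("HELLO"))
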