Commit 1b5c6f88 authored by Baber

add MetricConfig

parent 6b3f3f7e
-import abc
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import Any, Callable, List, Optional, Union
@@ -84,7 +83,7 @@ class GroupConfig(dict):
        return str(value)

-class ConfigurableGroup(abc.ABC):
+class ConfigurableGroup:
    def __init__(
        self,
        config: Optional[dict] = None,
...
@@ -14,10 +14,23 @@ class Instance:
    arguments: tuple
    idx: int
    metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
-        default_factory=lambda: (None, None, None)
-    )
-    resps: list = field(default_factory=list)
-    filtered_resps: dict = field(default_factory=dict)
+        default_factory=lambda: (None, None, None),
+        metadata=dict(
+            description="Metadata tuple containing task name, document ID, and number of repeats."
+        ),
+    )
+    resps: list = field(
+        default_factory=list,
+        metadata=dict(
+            description="List of responses from the model for this instance."
+        ),
+    )
+    filtered_resps: dict = field(
+        default_factory=dict,
+        metadata=dict(
+            description="List of filtered responses for this instance, keyed by filter name."
+        ),
+    )

    # initialized after init
    task_name: Optional[str] = None
@@ -29,7 +42,7 @@ class Instance:
        self.task_name, self.doc_id, self.repeats = self.metadata

    @property
-    def args(self):
+    def args(self) -> tuple:
        """
        Returns (string,) where `string` is the string to calculate loglikelihood over
        """
...
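The `metadata=dict(description=...)` arguments above attach documentation to the dataclass fields themselves. A small sketch, using only the standard `dataclasses` API, of how that metadata can be read back (the loop is illustrative, not part of the commit):

    from dataclasses import fields

    from lm_eval.api.instance import Instance

    # print the description attached to each Instance field, if any
    for f in fields(Instance):
        print(f.name, "->", f.metadata.get("description", "<no description>"))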
@@ -7,7 +7,6 @@ from collections.abc import Iterable
from typing import List
import numpy as np
-import sacrebleu
from lm_eval.api.registry import register_aggregation, register_metric
@@ -89,6 +88,8 @@ def bleu(items):
    Higher is better
    """
+    import sacrebleu
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
@@ -104,6 +105,8 @@ def chrf(items):
    Higher is better # TODO I think
    """
+    import sacrebleu
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
@@ -120,6 +123,8 @@ def ter(items):
    Lower is better
    """
+    import sacrebleu
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
...
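Moving `import sacrebleu` inside `bleu`, `chrf`, and `ter` means the dependency is only loaded when one of those metrics is actually computed. A minimal, standalone sketch of the deferred-import pattern (not the harness's exact implementation):

    def bleu_score(predictions: list, references: list) -> float:
        # sacrebleu is imported lazily, so merely importing this module
        # does not require sacrebleu to be installed
        import sacrebleu

        return sacrebleu.corpus_bleu(predictions, [references]).score

    print(bleu_score(["the cat sat"], ["the cat sat on the mat"]))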
import logging
-from typing import Callable, Dict, Union
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
-import evaluate as hf_evaluate
-from lm_eval.api.model import LM
+if TYPE_CHECKING:
+    from lm_eval.api.model import LM
eval_logger = logging.getLogger(__name__)
@@ -12,6 +11,8 @@ MODEL_REGISTRY = {}
def register_model(*names):
+    from lm_eval.api.model import LM
    # either pass a list or a single alias.
    # function receives them as a tuple of strings
@@ -31,7 +32,7 @@ def register_model(*names):
    return decorate
-def get_model(model_name):
+def get_model(model_name: str) -> type["LM"]:
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError:
@@ -46,7 +47,7 @@ ALL_TASKS = set()
func2task_index = {}
-def register_task(name):
+def register_task(name: str):
    def decorate(fn):
        assert name not in TASK_REGISTRY, (
            f"task named '{name}' conflicts with existing registered task!"
@@ -120,7 +121,7 @@ def register_metric(**args):
    return decorate
-def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
+def get_metric(name: str, hf_evaluate_metric=False) -> Optional[Callable]:
    if not hf_evaluate_metric:
        if name in METRIC_REGISTRY:
            return METRIC_REGISTRY[name]
@@ -130,6 +131,8 @@ def get_metric(name: str, hf_evaluate_metric=False) -> Optional[Callable]:
        )
    try:
+        import evaluate as hf_evaluate
        metric_object = hf_evaluate.load(name)
        return metric_object.compute
    except Exception:
@@ -150,21 +153,21 @@ def register_aggregation(name: str):
    return decorate
-def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
+def get_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
    try:
        return AGGREGATION_REGISTRY[name]
    except KeyError:
        eval_logger.warning(f"{name} not a registered aggregation metric!")
-def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
+def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
    try:
        return METRIC_AGGREGATION_REGISTRY[name]
    except KeyError:
        eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
-def is_higher_better(metric_name) -> bool:
+def is_higher_better(metric_name) -> Optional[bool]:
    try:
        return HIGHER_IS_BETTER_REGISTRY[metric_name]
    except KeyError:
...
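With the `Optional[...]` return types above, a registry miss now surfaces as a logged warning plus a `None` return rather than an exception, so callers are expected to guard against it. A hypothetical caller-side sketch (not part of the commit):

    from lm_eval.api.registry import get_aggregation

    agg_fn = get_aggregation("mean")
    if agg_fn is None:
        # get_aggregation already logged a warning; escalate to a hard error here
        raise ValueError("aggregation 'mean' is not registered")
    print(agg_fn([0.0, 1.0, 1.0]))  # arithmetic mean of the per-sample scores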
import logging
import warnings
from functools import partial
-from typing import TYPE_CHECKING, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union
import datasets
@@ -181,7 +181,7 @@ class ContextSampler:
        return chat_history

-    def sample(self, n: int):
+    def sample(self, n: int) -> Sequence[dict]:
        """
        Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
        """
@@ -190,7 +190,7 @@ class ContextSampler:

class FirstNSampler(ContextSampler):
-    def sample(self, n: int) -> None:
+    def sample(self, n: int) -> Sequence[dict]:
        """
        Draw the first `n` samples in order from the specified split.
        Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
...
@@ -6,6 +6,7 @@ import re
from collections.abc import Callable
from copy import deepcopy
from dataclasses import asdict, dataclass
+from functools import cached_property
from inspect import getsource
from typing import (
    Any,
@@ -23,6 +24,7 @@ from typing import (
import datasets
import numpy as np
from tqdm import tqdm
+from typing_extensions import deprecated
from lm_eval import utils
from lm_eval.api import samplers
@@ -51,6 +53,43 @@ ALL_OUTPUT_TYPES = [
eval_logger = logging.getLogger(__name__)
+@dataclass
+class MetricConfig:
+    """Encapsulates information about a single metric."""
+
+    name: str
+    fn: Optional[Callable] = None
+    kwargs: Optional[dict] = None
+    aggregation_fn: Optional[Callable] = None
+    higher_is_better: bool = True
+    hf_evaluate: bool = False
+
+    @cached_property
+    def metric_names(self) -> str:
+        return self.name
+
+    @cached_property
+    def aggregation(self) -> Callable:
+        if self.aggregation_fn is None:
+            return get_aggregation(self.name)
+        return self.aggregation_fn
+
+    @cached_property
+    def _higher_is_better(self) -> bool:
+        if self.higher_is_better is None:
+            return is_higher_better(self.name)
+        return self.higher_is_better
+
+
+@dataclass
+class FilterConfig:
+    """Encapsulates information about a filter."""
+
+    name: str
+    fn: Optional[Callable] = None
+    kwargs: Optional[dict] = None
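For orientation, a `MetricConfig` like the one added above can be wired up directly from the registry helpers; a hand-rolled sketch (metric name and values are illustrative, not part of the commit):

    from lm_eval.api.registry import (
        get_metric,
        get_metric_aggregation,
        is_higher_better,
    )
    from lm_eval.api.task import MetricConfig  # added by this commit

    em = MetricConfig(
        name="exact_match",
        fn=get_metric("exact_match"),
        kwargs={},
        aggregation_fn=get_metric_aggregation("exact_match"),
        higher_is_better=is_higher_better("exact_match"),
    )
    print(em.metric_names, em._higher_is_better)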
@dataclass
class TaskConfig(dict):
    # task naming/registry
@@ -133,6 +172,93 @@ class TaskConfig(dict):
                f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
            )
+        if self.metric_list is not None:
+            for metric_config in self.metric_list:
+                if "metric" not in metric_config:
+                    raise ValueError(
+                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
+                    )
+
+    def get_metrics(self) -> list["MetricConfig"]:
+        metrics = []
+        if self.metric_list is None:
+            _metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
+            metrics.extend(
+                MetricConfig(
+                    name=metric_name,
+                    fn=get_metric(metric_name),
+                    aggregation_fn=get_metric_aggregation(metric_name),
+                    higher_is_better=is_higher_better(metric_name),
+                )
+                for metric_name in _metric_list
+            )
+        else:
+            for metric_config in self.metric_list:
+                if "metric" not in metric_config:
+                    raise ValueError(
+                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
+                    )
+                metric_name = metric_config["metric"]
+                _metric_fn_kwargs = {
+                    key: metric_config[key]
+                    for key in metric_config
+                    if key
+                    not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
+                }
+                _hf_evaluate_metric: bool = metric_config.get("hf_evaluate", False)
+                _metric_fn = None
+                _aggregation = None
+
+                if self.process_results is not None:
+                    # User will compute metrics inside `process_results()`
+                    _metric_name = None
+                    _metric_fn_kwargs = {}
+                elif callable(metric_name):
+                    # User passed a function object
+                    _metric_name = metric_name.__name__
+                    _metric_fn = metric_name.__call__
+                else:
+                    # Normal: look up by name
+                    _metric_name = get_metric(metric_name, _hf_evaluate_metric)
+
+                # ---------- 3. Decide how to aggregate examples ----------
+                if "aggregation" in metric_config:
+                    if isinstance(_agg_name := metric_config["aggregation"], str):
+                        _aggregation = get_aggregation(_agg_name)
+                    elif callable(_agg_name):  # noqa: E721
+                        _aggregation = metric_config["aggregation"]
+                else:
+                    INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
+                    _aggregation = get_metric_aggregation(metric_name)
+                    eval_logger.warning(
+                        f"[Task: {self.task}] metric {metric_name} is defined, but aggregation is not. "
+                        f"using default "
+                        f"aggregation={INV_AGG_REGISTRY[_aggregation]}"
+                    )
+
+                # ---------- 4. Determine “higher-is-better” semantics ----------
+                if "higher_is_better" in metric_config:
+                    _higher_is_better = metric_config["higher_is_better"]
+                else:
+                    eval_logger.warning(
+                        f"[Task: {self.task}] metric {metric_name} is defined, but higher_is_better is not. "
+                        f"using default "
+                        f"higher_is_better={is_higher_better(metric_name)}"
+                    )
+                    _higher_is_better = is_higher_better(metric_name)
+
+                metrics.append(
+                    MetricConfig(
+                        name=_metric_name,
+                        fn=_metric_fn,
+                        kwargs=_metric_fn_kwargs,
+                        aggregation_fn=_aggregation,
+                        higher_is_better=_higher_is_better,
+                        hf_evaluate=_hf_evaluate_metric,
+                    )
+                )
+        return metrics

    def __getitem__(self, item):
        return getattr(self, item)
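As a rough usage sketch, `get_metrics()` is driven entirely by the config: with no `metric_list` it falls back to the defaults registered for the task's output type, otherwise it consumes the same dict entries tasks already declare. The task name below is hypothetical:

    from lm_eval.api.task import TaskConfig

    cfg = TaskConfig(task="demo_task", output_type="loglikelihood")
    for m in cfg.get_metrics():
        # each entry is a fully wired MetricConfig
        print(m.name, m.higher_is_better)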
@@ -534,7 +660,7 @@ class Task(abc.ABC):
        """
        pass

-    @abc.abstractmethod
+    @deprecated("not used anymore")
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
@@ -543,7 +669,7 @@ class Task(abc.ABC):
        """
        pass

-    @abc.abstractmethod
+    @deprecated("not used anymore")
    def higher_is_better(self):
        """
        :returns: {str: bool}
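`typing_extensions.deprecated` (PEP 702) both flags call sites for type checkers and, at runtime, makes the wrapped callable emit a `DeprecationWarning`. A tiny standalone sketch of the behaviour relied on here:

    import warnings

    from typing_extensions import deprecated

    @deprecated("not used anymore")
    def aggregation():
        return {}

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        aggregation()
    print(caught[0].category)  # <class 'DeprecationWarning'>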
@@ -661,23 +787,13 @@ class Task(abc.ABC):
        Parameters:
        - metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
        """
-        (
-            self._metric_fn_list,
-            self._aggregation_list,
-            self._metric_fn_kwargs,
-            self._higher_is_better,
-        ) = ({}, {}, {}, {})
-        self._metric_fn_list[metric_name] = get_metric(metric_name)
-        self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
-        self._higher_is_better[metric_name] = is_higher_better(metric_name)
-        self._metric_fn_kwargs[metric_name] = {}
-        if not isinstance(self, ConfigurableTask):
-            self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
-            self.aggregation = lambda: {
-                metric_name: get_metric_aggregation(metric_name)
-            }
-        setattr(self._config, "metric_list", [{"metric": metric_name}])
-        setattr(self._config, "process_results", None)
+        # if not isinstance(self, ConfigurableTask):
+        #     self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
+        #     self.aggregation = lambda: {
+        #         metric_name: get_metric_aggregation(metric_name)
+        #     }
+        setattr(self._config, "metric_list", [MetricConfig(name=metric_name)])
+        setattr(self._config, "process_results", lambda *args: {"bypass": 0})

    def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
        self.fewshot_rnd = random.Random(seed)
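In effect, `override_metric` now routes everything through the config: `metric_list` is replaced by a single `MetricConfig` and `process_results` is stubbed out to return a constant score. A hedged sketch of what a caller would observe (`task` is a hypothetical Task instance, and the example peeks at the private `_config`):

    task.override_metric("bypass")
    print(task._config.process_results(None, None))    # {'bypass': 0}
    print([m.name for m in task._config.metric_list])  # ['bypass']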
@@ -739,7 +855,7 @@ class ConfigurableTask(Task):
        cache_dir=None,
        download_mode=None,
        config: Optional[dict] = None,
-    ) -> None:  # TODO no super() call here
+    ) -> None:
        # Get pre-configured attributes
        self._config = self.CONFIG
@@ -784,83 +900,7 @@ class ConfigurableTask(Task):
        if self.config.dataset_name is not None:
            self.DATASET_NAME = self.config.dataset_name

-        self._metric_fn_list = {}
-        self._metric_fn_kwargs = {}
-        self._aggregation_list = {}
-        self._higher_is_better = {}
-
-        if self.config.metric_list is None:
-            # TODO: handle this in TaskConfig.__post_init__ ?
-            _metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
-
-            for metric_name in _metric_list:
-                self._metric_fn_list[metric_name] = get_metric(metric_name)
-                self._metric_fn_kwargs[metric_name] = {}
-                self._aggregation_list[metric_name] = get_metric_aggregation(
-                    metric_name
-                )
-                self._higher_is_better[metric_name] = is_higher_better(metric_name)
-        else:
-            for metric_config in self.config.metric_list:
-                if "metric" not in metric_config:
-                    raise ValueError(
-                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
-                    )
-                metric_name = metric_config["metric"]
-                kwargs = {
-                    key: metric_config[key]
-                    for key in metric_config
-                    if key
-                    not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
-                }
-                hf_evaluate_metric = (
-                    "hf_evaluate" in metric_config
-                    and metric_config["hf_evaluate"] is True
-                )
-
-                if self.config.process_results is not None:
-                    self._metric_fn_list[metric_name] = None
-                    self._metric_fn_kwargs[metric_name] = {}
-                elif callable(metric_name):
-                    metric_fn = metric_name.__call__
-                    metric_name = metric_name.__name__
-                    self._metric_fn_list[metric_name] = metric_fn
-                    self._metric_fn_kwargs[metric_name] = kwargs
-                else:
-                    self._metric_fn_list[metric_name] = get_metric(
-                        metric_name, hf_evaluate_metric
-                    )
-                    self._metric_fn_kwargs[metric_name] = kwargs
-
-                if "aggregation" in metric_config:
-                    agg_name = metric_config["aggregation"]
-                    if isinstance(agg_name, str):
-                        self._aggregation_list[metric_name] = get_aggregation(agg_name)
-                    elif callable(agg_name):  # noqa: E721
-                        self._aggregation_list[metric_name] = metric_config[
-                            "aggregation"
-                        ]
-                else:
-                    INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
-                    metric_agg = get_metric_aggregation(metric_name)
-                    eval_logger.warning(
-                        f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. "
-                        f"using default "
-                        f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
-                    )
-                    self._aggregation_list[metric_name] = metric_agg
-
-                if "higher_is_better" in metric_config:
-                    self._higher_is_better[metric_name] = metric_config[
-                        "higher_is_better"
-                    ]
-                else:
-                    eval_logger.warning(
-                        f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. "
-                        f"using default "
-                        f"higher_is_better={is_higher_better(metric_name)}"
-                    )
-                    self._higher_is_better[metric_name] = is_higher_better(metric_name)
+        self.metric_list: list[MetricConfig] = self._config.get_metrics()

        self.download(self.config.dataset_kwargs)
        self._training_docs = None
@@ -868,17 +908,23 @@ class ConfigurableTask(Task):
        if self.config.filter_list is not None:
            self._filters = []
-            for filter_config in self.config.filter_list:
-                filter_name = filter_config["name"]
-                filter_functions = filter_config["filter"]
-                components = []
-                for function in filter_functions:
-                    kwargs = {
-                        key: function[key] for key in function if key != "function"
-                    }
-                    components.append([function["function"], kwargs])
-                filter_pipeline = build_filter_ensemble(filter_name, components)
-                self._filters.append(filter_pipeline)
+            if isinstance(self.config.filter_list, dict):
+                for filter_config in self.config.filter_list:
+                    self._filters.append(
+                        build_filter_ensemble(
+                            filter_config["name"],
+                            [
+                                [
+                                    {
+                                        key: function[key]
+                                        for key in function
+                                        if key != "function"
+                                    }
+                                ]
+                                for function in filter_config["filter"]
+                            ],
+                        )
+                    )
        else:
            # TODO: handle repeats in a more general way rather than just discarding
            eval_logger.debug(
@@ -1476,7 +1522,7 @@ class ConfigurableTask(Task):
            arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]

            # TODO: we should raise a warning telling users this will at most ~2x runtime.
-            if "acc_mutual_info" in self._metric_fn_list.keys():
+            if "acc_mutual_info" in [m.metric_names for m in self.metric_list]:
                # if we are calculating multiple choice accuracy
                # using mutual information instead of raw loglikelihood as metric, need unconditional lls.
@@ -1543,7 +1589,7 @@ class ConfigurableTask(Task):
            return self.config.process_results(doc, results)

        result_dict = {}
-        use_metric = list(self._metric_fn_list.keys())
+        use_metric = list(m.metric_names for m in self.metric_list)
        if self.OUTPUT_TYPE == "loglikelihood":
            results = results[0]
            ll, is_greedy = results
@@ -1579,10 +1625,7 @@ class ConfigurableTask(Task):
            choices = self.doc_to_choice(doc)
            completion_len = np.array([float(len(i)) for i in choices])

-            if (
-                2 * len(choices) == len(lls)
-                and "acc_mutual_info" in self._metric_fn_list.keys()
-            ):
+            if 2 * len(choices) == len(lls) and "acc_mutual_info" in use_metric:
                # then we are doing mutual info.
                # this stores the "dryrun" / unconditional answer loglikelihoods
                # as we extend the args list with unconditional ("", continuation) pairs
@@ -1667,12 +1710,12 @@ class ConfigurableTask(Task):
                gold = list(gold)
            # TODO: handle this better
            elif type(gold) is not type(result) and not (
-                "bypass" in self._metric_fn_list.keys() or isinstance(result, list)
+                "bypass" in use_metric or isinstance(result, list)
            ):
                # cast gold to the same type as result
                gold = type(result)(gold)

-            for metric in self._metric_fn_list.keys():
+            for metric in self.metric_list:
                if self.multiple_target:
                    # in the case where we have multiple targets,
                    # return true if any are true
@@ -1682,28 +1725,26 @@ class ConfigurableTask(Task):
                        # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
                        # print(gold)
                        gold = [gold]
-                    if metric == "exact_match":
+                    if metric.name == "exact_match":
                        result = [result for _ in range(len(gold))]
-                        scores = self._metric_fn_list[metric](
+                        scores = metric.fn(
                            references=gold,
                            predictions=result,
-                            **self._metric_fn_kwargs[metric],
+                            **metric.kwargs,
                        )[metric]
                        result_score = 1.0 if scores > 0.0 else 0.0
                    else:
                        for gold_option in gold:
                            try:
-                                result_score = self._metric_fn_list[metric](
+                                result_score = metric.fn(
                                    references=[gold_option],
                                    predictions=[result],
-                                    **self._metric_fn_kwargs[metric],
+                                    **metric.kwargs,
                                )
                            except (
                                TypeError
                            ):  # TODO: this is hacky and I don't want to do it
-                                result_score = self._metric_fn_list[metric](
-                                    [gold_option, result]
-                                )
+                                result_score = metric.fn([gold_option, result])
                            if isinstance(result_score, dict):
                                # TODO: this handles the case where HF evaluate returns a dict.
                                result_score = result_score[metric]
@@ -1714,13 +1755,13 @@ class ConfigurableTask(Task):
                            result_score = 0.0
                else:
                    try:
-                        result_score = self._metric_fn_list[metric](
+                        result_score = metric.fn(
                            references=[gold],
                            predictions=[result],
-                            **self._metric_fn_kwargs[metric],
+                            **metric.kwargs,
                        )
                    except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
-                        result_score = self._metric_fn_list[metric]([gold, result])
+                        result_score = metric.fn([gold, result])
                    if isinstance(result_score, dict):
                        # TODO: this handles the case where HF evaluate returns a dict.
                        # This allows for multiple metrics to be returned from the same function
@@ -1737,10 +1778,10 @@ class ConfigurableTask(Task):
        return result_dict

    def aggregation(self) -> dict:
-        return self._aggregation_list
+        return {k.name: k.aggregation_fn for k in self.metric_list}

    def higher_is_better(self) -> dict:
-        return self._higher_is_better
+        return {k.name: k.higher_is_better for k in self.metric_list}

    def get_config(self, key: str) -> Any:
        return getattr(self._config, key, None)
...
@@ -272,7 +272,7 @@ def simple_evaluate(
    # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
    # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
-    def _adjust_config(task_dict):
+    def _adjust_config(task_dict: dict[str, "Task"]) -> dict[str, "Task"]:
        adjusted_task_dict = {}
        for task_name, task_obj in task_dict.items():
            if isinstance(task_obj, dict):
...
@@ -12,7 +12,7 @@ from lm_eval.api.metrics import (
    pooled_sample_stderr,
    stderr_for_metric,
)
-from lm_eval.api.task import Task
+from lm_eval.api.task import ConfigurableTask, Task
from lm_eval.utils import positional_deprecated
@@ -58,7 +58,7 @@ class TaskOutput:
        group_alias=None,
        is_group=None,
    ):
-        self.task = task
+        self.task: Union[Task, ConfigurableTask] = task
        self.task_config = task_config
        self.task_name = task_name
        self.group_name = group_name
...
from functools import partial
-from typing import List
+from typing import List, Union
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.registry import get_filter
@@ -8,7 +8,7 @@ from . import custom, extraction, selection, transformation
def build_filter_ensemble(
-    filter_name: str, components: List[List[str]]
+    filter_name: str, components: list[Union[list[dict], list[str]]]
) -> FilterEnsemble:
    """
    Create a filtering pipeline.
...
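The widened `components` type reflects that each pipeline step may now be described by a dict of kwargs in addition to a bare function name. A sketch of the shape being passed around (the filter names and regex are illustrative):

    # hypothetical filter_list entry, as it would look after a task YAML is loaded
    filter_config = {
        "name": "strict-match",
        "filter": [
            {"function": "regex", "regex_pattern": r"#### (\-?[0-9\.\,]+)"},
            {"function": "take_first"},
        ],
    }

    # per step, everything except the "function" key is treated as that filter's kwargs
    kwargs_per_step = [
        {k: v for k, v in fn.items() if k != "function"}
        for fn in filter_config["filter"]
    ]
    print(kwargs_per_step)  # [{'regex_pattern': '#### (\\-?[0-9\\.\\,]+)'}, {}]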