Commit 787aea5d authored by Baber
Browse files

nit

parent fb63ac0f
......@@ -2,8 +2,6 @@ from dataclasses import dataclass
from functools import cached_property
from typing import Any, Callable, List, Optional
from lm_eval.api.registry import get_aggregation, is_higher_better
@dataclass
class MetricConfig:
......@@ -23,12 +21,16 @@ class MetricConfig:
@cached_property
def aggregation(self) -> Callable:
    """Return the aggregation callable used for this metric.

    An explicitly configured ``aggregation_fn`` is returned unchanged. When
    it is ``None``, the callable is looked up via
    ``get_aggregation(self.name)``. ``cached_property`` stores the result on
    the instance, so the lookup runs at most once per instance.
    """
    # Imported at call time rather than at module load
    # (presumably to avoid an import cycle with the registry -- confirm).
    from lm_eval.api.registry import get_aggregation

    configured = self.aggregation_fn
    if configured is not None:
        return configured
    return get_aggregation(self.name)
@cached_property
def _higher_is_better(self) -> bool:
    """Return the effective ``higher_is_better`` value for this metric.

    An explicitly configured ``higher_is_better`` is returned unchanged. When
    it is ``None``, the value is looked up via
    ``is_higher_better(self.name)``. ``cached_property`` stores the result on
    the instance, so the lookup runs at most once per instance.
    """
    # Imported at call time rather than at module load
    # (presumably to avoid an import cycle with the registry -- confirm).
    from lm_eval.api.registry import is_higher_better

    explicit = self.higher_is_better
    if explicit is not None:
        return explicit
    return is_higher_better(self.name)
......
......@@ -4,17 +4,8 @@ from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
get_aggregation,
get_metric,
get_metric_aggregation,
is_higher_better,
)
from lm_eval.config.metric import MetricConfig
from lm_eval.config.utils import maybe_serialize
from lm_eval.filters import build_filter_ensemble
if TYPE_CHECKING:
......@@ -241,6 +232,15 @@ class TaskConfig(dict):
@property
def get_metrics(self) -> list["MetricConfig"]:
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
get_aggregation,
get_metric,
get_metric_aggregation,
is_higher_better,
)
metrics = []
if self.metric_list is None:
# ---------- 1. If no metrics defined, use defaults for output type ----------
......@@ -258,7 +258,7 @@ class TaskConfig(dict):
for metric_name in _metric_list
)
else:
# ---------- 2. How will the samples be evaluated ----------
# ---------- 2. How will the outputs be evaluated ----------
for metric_config in self.metric_list:
metric_name = metric_config["metric"]
_metric_fn_kwargs = {
......@@ -323,6 +323,8 @@ class TaskConfig(dict):
@property
def get_filters(self) -> list["FilterEnsemble"]:
from lm_eval.filters import build_filter_ensemble
if not self.filter_list:
eval_logger.debug(
"No custom filters defined; falling back to 'take_first' for handling repeats."
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment