"vscode:/vscode.git/clone" did not exist on "65fad189a8c81fa10ddcd4c3474c1bbcbd87bb35"
Commit 787aea5d authored by Baber's avatar Baber
Browse files

nit

parent fb63ac0f
...@@ -2,8 +2,6 @@ from dataclasses import dataclass ...@@ -2,8 +2,6 @@ from dataclasses import dataclass
from functools import cached_property from functools import cached_property
from typing import Any, Callable, List, Optional from typing import Any, Callable, List, Optional
from lm_eval.api.registry import get_aggregation, is_higher_better
@dataclass @dataclass
class MetricConfig: class MetricConfig:
...@@ -23,12 +21,16 @@ class MetricConfig: ...@@ -23,12 +21,16 @@ class MetricConfig:
@cached_property @cached_property
def aggregation(self) -> Callable: def aggregation(self) -> Callable:
from lm_eval.api.registry import get_aggregation
if self.aggregation_fn is None: if self.aggregation_fn is None:
return get_aggregation(self.name) return get_aggregation(self.name)
return self.aggregation_fn return self.aggregation_fn
@cached_property @cached_property
def _higher_is_better(self) -> bool: def _higher_is_better(self) -> bool:
from lm_eval.api.registry import is_higher_better
if self.higher_is_better is None: if self.higher_is_better is None:
return is_higher_better(self.name) return is_higher_better(self.name)
return self.higher_is_better return self.higher_is_better
......
...@@ -4,17 +4,8 @@ from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union ...@@ -4,17 +4,8 @@ from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
from lm_eval.api.filter import FilterEnsemble from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType from lm_eval.api.instance import OutputType
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
get_aggregation,
get_metric,
get_metric_aggregation,
is_higher_better,
)
from lm_eval.config.metric import MetricConfig from lm_eval.config.metric import MetricConfig
from lm_eval.config.utils import maybe_serialize from lm_eval.config.utils import maybe_serialize
from lm_eval.filters import build_filter_ensemble
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -241,6 +232,15 @@ class TaskConfig(dict): ...@@ -241,6 +232,15 @@ class TaskConfig(dict):
@property @property
def get_metrics(self) -> list["MetricConfig"]: def get_metrics(self) -> list["MetricConfig"]:
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
get_aggregation,
get_metric,
get_metric_aggregation,
is_higher_better,
)
metrics = [] metrics = []
if self.metric_list is None: if self.metric_list is None:
# ---------- 1. If no metrics defined, use defaults for output type ---------- # ---------- 1. If no metrics defined, use defaults for output type ----------
...@@ -258,7 +258,7 @@ class TaskConfig(dict): ...@@ -258,7 +258,7 @@ class TaskConfig(dict):
for metric_name in _metric_list for metric_name in _metric_list
) )
else: else:
# ---------- 2. How will the samples be evaluated ---------- # ---------- 2. How will the outputs be evaluated ----------
for metric_config in self.metric_list: for metric_config in self.metric_list:
metric_name = metric_config["metric"] metric_name = metric_config["metric"]
_metric_fn_kwargs = { _metric_fn_kwargs = {
...@@ -323,6 +323,8 @@ class TaskConfig(dict): ...@@ -323,6 +323,8 @@ class TaskConfig(dict):
@property @property
def get_filters(self) -> list["FilterEnsemble"]: def get_filters(self) -> list["FilterEnsemble"]:
from lm_eval.filters import build_filter_ensemble
if not self.filter_list: if not self.filter_list:
eval_logger.debug( eval_logger.debug(
"No custom filters defined; falling back to 'take_first' for handling repeats." "No custom filters defined; falling back to 'take_first' for handling repeats."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment