Commit 5b527a71 authored by lintangsutawika
Browse files

update format

parent 83c070d4
...@@ -5,7 +5,7 @@ import random ...@@ -5,7 +5,7 @@ import random
import re import re
from collections.abc import Callable from collections.abc import Callable
from copy import deepcopy from copy import deepcopy
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass
from inspect import getsource from inspect import getsource
from typing import ( from typing import (
Any, Any,
...@@ -51,18 +51,20 @@ ALL_OUTPUT_TYPES = [ ...@@ -51,18 +51,20 @@ ALL_OUTPUT_TYPES = [
eval_logger = logging.getLogger("lm-eval") eval_logger = logging.getLogger("lm-eval")
@dataclass @dataclass
class AggMetricConfig(dict): class AggMetricConfig(dict):
metric: Optional[str] = "acc" metric: Optional[str] = "acc"
metric_alias: Optional[str] = "acc" metric_alias: Optional[str] = None
aggregation: Optional[str] = "mean" aggregation: Optional[str] = "mean"
weight_by_size: Optional[str] = False weight_by_size: Optional[str] = False
filter_list: Optional[Union[str,list]] = "none" filter_list: Optional[Union[str, list]] = "none"
def __post_init__(self): def __post_init__(self):
if isinstance(self.filter_list, str): if isinstance(self.filter_list, str):
self.filter_list = [self.filter_list] self.filter_list = [self.filter_list]
@dataclass @dataclass
class GroupConfig(dict): class GroupConfig(dict):
group: Optional[str] = None group: Optional[str] = None
...@@ -83,13 +85,13 @@ class GroupConfig(dict): ...@@ -83,13 +85,13 @@ class GroupConfig(dict):
return setattr(self, item, value) return setattr(self, item, value)
def __post_init__(self): def __post_init__(self):
if self.aggregate_metric_list is not None: if self.aggregate_metric is not None:
if isinstance(self.aggregate_metric_list, dict): if isinstance(self.aggregate_metric, dict):
self.aggregate_metric_list = [self.aggregate_metric_list] self.aggregate_metric = [self.aggregate_metric]
self.aggregate_metric_list = [ self.aggregate_metric = [
AggMetricConfig(**item) if isinstance(item, dict) else item AggMetricConfig(**item) if isinstance(item, dict) else item
for item in self.aggregate_metric_list for item in self.aggregate_metric
] ]
def to_dict(self, keep_callable: bool = False) -> dict: def to_dict(self, keep_callable: bool = False) -> dict:
......
...@@ -661,14 +661,14 @@ def evaluate( ...@@ -661,14 +661,14 @@ def evaluate(
for metric_config in agg_metric_list: for metric_config in agg_metric_list:
for filter in metric_config["filter_list"]: for filter in metric_config["filter_list"]:
if metric != ",".join([metric_config["metric"], filter]): if metric != ",".join(
[metric_config["metric"], filter]
):
continue continue
# compute group's pooled metric and stderr # compute group's pooled metric and stderr
if metric_config["aggregation"] == "mean": if metric_config["aggregation"] == "mean":
aggregate_fn = ( aggregate_fn = lm_eval.api.metrics.aggregate_subtask_metrics
lm_eval.api.metrics.aggregate_subtask_metrics
)
else: else:
aggregate_fn = metric_config["aggregation"] aggregate_fn = metric_config["aggregation"]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment