Commit 5b527a71 authored by lintangsutawika's avatar lintangsutawika
Browse files

update format

parent 83c070d4
......@@ -5,7 +5,7 @@ import random
import re
from collections.abc import Callable
from copy import deepcopy
from dataclasses import asdict, dataclass, field
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import (
Any,
......@@ -51,18 +51,20 @@ ALL_OUTPUT_TYPES = [
eval_logger = logging.getLogger("lm-eval")
@dataclass
class AggMetricConfig(dict):
metric: Optional[str] = "acc"
metric_alias: Optional[str] = "acc"
metric_alias: Optional[str] = None
aggregation: Optional[str] = "mean"
weight_by_size: Optional[str] = False
filter_list: Optional[Union[str,list]] = "none"
filter_list: Optional[Union[str, list]] = "none"
def __post_init__(self):
if isinstance(self.filter_list, str):
self.filter_list = [self.filter_list]
@dataclass
class GroupConfig(dict):
group: Optional[str] = None
......@@ -83,13 +85,13 @@ class GroupConfig(dict):
return setattr(self, item, value)
def __post_init__(self):
if self.aggregate_metric_list is not None:
if isinstance(self.aggregate_metric_list, dict):
self.aggregate_metric_list = [self.aggregate_metric_list]
if self.aggregate_metric is not None:
if isinstance(self.aggregate_metric, dict):
self.aggregate_metric = [self.aggregate_metric]
self.aggregate_metric_list = [
self.aggregate_metric = [
AggMetricConfig(**item) if isinstance(item, dict) else item
for item in self.aggregate_metric_list
for item in self.aggregate_metric
]
def to_dict(self, keep_callable: bool = False) -> dict:
......
......@@ -661,14 +661,14 @@ def evaluate(
for metric_config in agg_metric_list:
for filter in metric_config["filter_list"]:
if metric != ",".join([metric_config["metric"], filter]):
if metric != ",".join(
[metric_config["metric"], filter]
):
continue
# compute group's pooled metric and stderr
if metric_config["aggregation"] == "mean":
aggregate_fn = (
lm_eval.api.metrics.aggregate_subtask_metrics
)
aggregate_fn = lm_eval.api.metrics.aggregate_subtask_metrics
else:
aggregate_fn = metric_config["aggregation"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment