Commit 80d0f412 authored by lintangsutawika

change how aggregate_metric is loaded

parent 0f095f79
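
For orientation, a sketch of the per-metric aggregation entry this commit introduces. The five field names and their defaults come from the AggMetricConfig dataclass added below; the surrounding dict and its values are illustrative, not part of the commit.

# One aggregate_metric_list entry, annotated with the AggMetricConfig defaults
# added in this commit. Values here are illustrative assumptions.
example_entry = {
    "metric": "acc",          # subtask metric to aggregate (default "acc")
    "metric_alias": "acc",    # display alias (default "acc")
    "aggregation": "mean",    # "mean" maps to lm_eval.api.metrics.aggregate_subtask_metrics
    "weight_by_size": False,  # weight each subtask by its sample count (default False)
    "filter_list": "none",    # a bare string is normalized to ["none"] in __post_init__
}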
@@ -5,7 +5,7 @@ import random
import re
from collections.abc import Callable
from copy import deepcopy
from dataclasses import asdict, dataclass
from dataclasses import asdict, dataclass, field
from inspect import getsource
from typing import (
Any,
@@ -51,6 +51,17 @@ ALL_OUTPUT_TYPES = [
eval_logger = logging.getLogger("lm-eval")
@dataclass
class AggMetricConfig(dict):
metric: Optional[str] = "acc"
metric_alias: Optional[str] = "acc"
aggregation: Optional[str] = "mean"
weight_by_size: Optional[str] = False
filter_list: Optional[Union[str,list]] = "none"
def __post_init__(self):
if isinstance(self.filter_list, str):
self.filter_list = [self.filter_list]
@dataclass
class GroupConfig(dict):
@@ -58,10 +69,9 @@ class GroupConfig(dict):
group_alias: Optional[str] = None
task: Optional[Union[str, list]] = None
tag_to_task: Optional[str] = False
aggregate_metric: Optional[str] = False
aggregate_fn: Optional[str] = "mean"
weight_by_size: Optional[str] = False
metric_alias: Optional[str] = None # Still a placeholder
aggregate_metric_list: Optional[
Union[List[AggMetricConfig], AggMetricConfig, dict]
] = None
metadata: Optional[
dict
] = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
@@ -72,6 +82,16 @@
def __setitem__(self, item, value):
return setattr(self, item, value)
def __post_init__(self):
if self.aggregate_metric_list is not None:
if isinstance(self.aggregate_metric_list, dict):
self.aggregate_metric_list = [self.aggregate_metric_list]
self.aggregate_metric_list = [
AggMetricConfig(**item) if isinstance(item, dict) else item
for item in self.aggregate_metric_list
]
def to_dict(self, keep_callable: bool = False) -> dict:
"""dumps the current config as a dictionary object, as a printable format.
null fields will not be printed.
......
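As a quick illustration of the normalization the two __post_init__ hooks above perform: a group config may pass aggregate_metric_list as a single dict; GroupConfig wraps it in a list and converts each dict to an AggMetricConfig, which in turn wraps a string filter_list in a list. A minimal sketch follows; the import path and the assumption that all other GroupConfig fields have defaults are not shown in this view and are guesses.

from lm_eval.api.task import GroupConfig  # import path is an assumption

cfg = GroupConfig(
    task=["subtask_a", "subtask_b"],  # hypothetical subtasks
    aggregate_metric_list={"metric": "acc", "filter_list": "none"},
)

entry = cfg.aggregate_metric_list[0]                       # coerced to AggMetricConfig
print(entry.metric, entry.aggregation, entry.filter_list)  # acc mean ['none']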
@@ -616,13 +616,16 @@ def evaluate(
)
if (group_config is None) or (
group_config["aggregate_metric"] is False
group_config["aggregate_metric"] is None
):
results[group_or_task][" "] = " "
continue
show_group_table = (
show_group_table | group_config["aggregate_metric"]
if "aggregate_metric" in group_config:
agg_metric_list = group_config["aggregate_metric"]
show_group_table = show_group_table | bool(
group_config["aggregate_metric"]
)
task_list = _task_aggregation_list[group_or_task]
@@ -656,26 +659,36 @@
if metric in results[task]
]
# compute group's pooled metric and stderr
results[group_or_task][
metric
] = lm_eval.api.metrics.aggregate_subtask_metrics(
metrics,
sizes,
group_config["weight_by_size"],
)
# TODO: calculate grouped metric using aggregation fn
if "N/A" in stderrs:
results[group_or_task][stderr] = "N/A"
else:
results[group_or_task][
stderr
] = lm_eval.api.metrics.pooled_sample_stderr(
stderrs, sizes
)
# TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
# To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
# results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
for metric_config in agg_metric_list:
for filter in metric_config["filter_list"]:
if metric != ",".join([metric_config["metric"], filter]):
continue
# compute group's pooled metric and stderr
if metric_config["aggregation"] == "mean":
aggregate_fn = (
lm_eval.api.metrics.aggregate_subtask_metrics
)
else:
aggregate_fn = metric_config["aggregation"]
results[group_or_task][metric] = aggregate_fn(
metrics,
sizes,
metric_config["weight_by_size"],
)
# TODO: calculate grouped metric using aggregation fn
if "N/A" in stderrs:
results[group_or_task][stderr] = "N/A"
else:
results[group_or_task][
stderr
] = lm_eval.api.metrics.pooled_sample_stderr(
stderrs, sizes
)
# TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
# To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
# results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
results[group_or_task]["samples"] = sum(sizes)
group_metadata = group_config.get("metadata", None)
@@ -683,6 +696,7 @@
versions[group_or_task] = group_metadata.get(
"version", None
)
# print(results)
return results, versions, show_group_table, task_aggregation_list
results, versions, show_group_table, *_ = process_group(
......
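To make the new aggregation path concrete: for each AggMetricConfig entry and each filter in its filter_list, the loop above matches result columns named "<metric>,<filter>" and pools the per-subtask values. When aggregation is "mean" it calls lm_eval.api.metrics.aggregate_subtask_metrics with the metric values, the subtask sizes, and the weight_by_size flag, exactly as in the diff. A toy sketch with made-up numbers:

# Toy numbers, not library output; assumes aggregate_subtask_metrics computes a
# size-weighted mean when weight_by_size is True, as the call it replaces did.
import lm_eval.api.metrics

subtask_scores = [0.50, 0.80]  # per-subtask "acc,none" values (illustrative)
subtask_sizes = [100, 300]     # per-subtask sample counts (illustrative)

group_score = lm_eval.api.metrics.aggregate_subtask_metrics(
    subtask_scores, subtask_sizes, True  # weight_by_size=True
)
print(group_score)  # (0.50*100 + 0.80*300) / 400 = 0.725

# The group stderr is pooled with the same sizes via
# lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes), unless any subtask
# reports "N/A", in which case the group stderr is "N/A" as well.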