Commit 80d0f412 authored by lintangsutawika

change how aggregate_metric is loaded

parent 0f095f79
@@ -5,7 +5,7 @@ import random
 import re
 from collections.abc import Callable
 from copy import deepcopy
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
 from inspect import getsource
 from typing import (
     Any,
@@ -51,6 +51,17 @@ ALL_OUTPUT_TYPES = [
 eval_logger = logging.getLogger("lm-eval")
 
 
+@dataclass
+class AggMetricConfig(dict):
+    metric: Optional[str] = "acc"
+    metric_alias: Optional[str] = "acc"
+    aggregation: Optional[str] = "mean"
+    weight_by_size: Optional[str] = False
+    filter_list: Optional[Union[str, list]] = "none"
+
+    def __post_init__(self):
+        if isinstance(self.filter_list, str):
+            self.filter_list = [self.filter_list]
+
+
 @dataclass
 class GroupConfig(dict):
@@ -58,10 +69,9 @@ class GroupConfig(dict):
     group_alias: Optional[str] = None
     task: Optional[Union[str, list]] = None
     tag_to_task: Optional[str] = False
-    aggregate_metric: Optional[str] = False
-    aggregate_fn: Optional[str] = "mean"
-    weight_by_size: Optional[str] = False
-    metric_alias: Optional[str] = None  # Still a placeholder
+    aggregate_metric_list: Optional[
+        Union[List[AggMetricConfig], AggMetricConfig, dict]
+    ] = None
     metadata: Optional[
         dict
     ] = None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
@@ -72,6 +82,16 @@ class GroupConfig(dict):
     def __setitem__(self, item, value):
         return setattr(self, item, value)
 
+    def __post_init__(self):
+        if self.aggregate_metric_list is not None:
+            if isinstance(self.aggregate_metric_list, dict):
+                self.aggregate_metric_list = [self.aggregate_metric_list]
+
+            self.aggregate_metric_list = [
+                AggMetricConfig(**item) if isinstance(item, dict) else item
+                for item in self.aggregate_metric_list
+            ]
+
     def to_dict(self, keep_callable: bool = False) -> dict:
         """dumps the current config as a dictionary object, as a printable format.
         null fields will not be printed.
......
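For context, here is a minimal standalone sketch of what the new config plumbing above does. The two dataclasses are trimmed copies of the ones added in this commit (only the fields needed to show the normalization are kept), so treat it as illustrative rather than the actual lm_eval module:

from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class AggMetricConfig(dict):
    metric: Optional[str] = "acc"
    aggregation: Optional[str] = "mean"
    weight_by_size: Optional[str] = False
    filter_list: Optional[Union[str, list]] = "none"

    def __post_init__(self):
        # A bare filter name becomes a single-element list.
        if isinstance(self.filter_list, str):
            self.filter_list = [self.filter_list]


@dataclass
class GroupConfig(dict):
    group: Optional[str] = None
    task: Optional[Union[str, list]] = None
    aggregate_metric_list: Optional[
        Union[List[AggMetricConfig], AggMetricConfig, dict]
    ] = None

    def __post_init__(self):
        if self.aggregate_metric_list is not None:
            # Accept a single dict or a list; plain dicts are upgraded to AggMetricConfig.
            if isinstance(self.aggregate_metric_list, dict):
                self.aggregate_metric_list = [self.aggregate_metric_list]
            self.aggregate_metric_list = [
                AggMetricConfig(**item) if isinstance(item, dict) else item
                for item in self.aggregate_metric_list
            ]


cfg = GroupConfig(
    group="my_group",
    task=["task_a", "task_b"],
    aggregate_metric_list={"metric": "acc", "weight_by_size": True},
)
print(cfg.aggregate_metric_list[0].metric)       # acc
print(cfg.aggregate_metric_list[0].filter_list)  # ['none']

In other words, a group config that previously set aggregate_metric, aggregate_fn, and weight_by_size as separate top-level fields now supplies one or more entries under aggregate_metric_list, each carrying its own metric, aggregation, weighting, and filters.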
@@ -616,13 +616,16 @@ def evaluate(
                     )
 
                     if (group_config is None) or (
-                        group_config["aggregate_metric"] is False
+                        group_config["aggregate_metric"] is None
                     ):
                         results[group_or_task][" "] = " "
                         continue
 
-                    show_group_table = (
-                        show_group_table | group_config["aggregate_metric"]
-                    )
+                    if "aggregate_metric" in group_config:
+                        agg_metric_list = group_config["aggregate_metric"]
+
+                    show_group_table = show_group_table | bool(
+                        group_config["aggregate_metric"]
+                    )
 
                     task_list = _task_aggregation_list[group_or_task]
@@ -656,26 +659,36 @@ def evaluate(
                             if metric in results[task]
                         ]
 
-                        # compute group's pooled metric and stderr
-                        results[group_or_task][
-                            metric
-                        ] = lm_eval.api.metrics.aggregate_subtask_metrics(
-                            metrics,
-                            sizes,
-                            group_config["weight_by_size"],
-                        )
-                        # TODO: calculate grouped metric using aggregation fn
-                        if "N/A" in stderrs:
-                            results[group_or_task][stderr] = "N/A"
-                        else:
-                            results[group_or_task][
-                                stderr
-                            ] = lm_eval.api.metrics.pooled_sample_stderr(
-                                stderrs, sizes
-                            )
-                            # TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
-                            # To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
-                            # results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
+                        for metric_config in agg_metric_list:
+                            for filter in metric_config["filter_list"]:
+                                if metric != ",".join([metric_config["metric"], filter]):
+                                    continue
+
+                                # compute group's pooled metric and stderr
+                                if metric_config["aggregation"] == "mean":
+                                    aggregate_fn = (
+                                        lm_eval.api.metrics.aggregate_subtask_metrics
+                                    )
+                                else:
+                                    aggregate_fn = metric_config["aggregation"]
+
+                                results[group_or_task][metric] = aggregate_fn(
+                                    metrics,
+                                    sizes,
+                                    metric_config["weight_by_size"],
+                                )
+                                # TODO: calculate grouped metric using aggregation fn
+                                if "N/A" in stderrs:
+                                    results[group_or_task][stderr] = "N/A"
+                                else:
+                                    results[group_or_task][
+                                        stderr
+                                    ] = lm_eval.api.metrics.pooled_sample_stderr(
+                                        stderrs, sizes
+                                    )
+                                    # TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
+                                    # To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
+                                    # results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
 
                         results[group_or_task]["samples"] = sum(sizes)
 
                         group_metadata = group_config.get("metadata", None)
@@ -683,6 +696,7 @@ def evaluate(
                             versions[group_or_task] = group_metadata.get(
                                 "version", None
                             )
+            # print(results)
 
             return results, versions, show_group_table, task_aggregation_list
 
         results, versions, show_group_table, *_ = process_group(
......
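A rough sketch of the new aggregation loop in evaluate(): result keys in lm-eval have the form "<metric>,<filter>" (e.g. "acc,none"), and each AggMetricConfig is matched against that key before its aggregation function is applied. The weighted mean below is only a stand-in for lm_eval.api.metrics.aggregate_subtask_metrics, and the helper names are illustrative, not actual lm_eval internals:

from typing import Callable, List


def weighted_mean(metrics: List[float], sizes: List[int], weight_by_size: bool) -> float:
    # Stand-in for lm_eval.api.metrics.aggregate_subtask_metrics: a size-weighted
    # (or unweighted) mean over subtask scores.
    if not weight_by_size:
        sizes = [1] * len(sizes)
    return sum(m * s for m, s in zip(metrics, sizes)) / sum(sizes)


def select_aggregate_fn(metric_config: dict) -> Callable:
    # "mean" maps to the pooled-metric helper; anything else is assumed to be
    # a callable supplied via the config, mirroring the branch in the diff above.
    if metric_config["aggregation"] == "mean":
        return weighted_mean
    return metric_config["aggregation"]


metric_key = "acc,none"
metric_config = {
    "metric": "acc",
    "aggregation": "mean",
    "weight_by_size": True,
    "filter_list": ["none"],
}

for flt in metric_config["filter_list"]:
    if metric_key != ",".join([metric_config["metric"], flt]):
        continue
    aggregate_fn = select_aggregate_fn(metric_config)
    # Two subtasks: acc 0.50 over 100 docs and 0.80 over 300 docs -> 0.725 pooled.
    print(aggregate_fn([0.50, 0.80], [100, 300], metric_config["weight_by_size"]))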