Commit 5a98162d authored by lintangsutawika's avatar lintangsutawika
Browse files

fixed size configuration

parent 6da6d187
......@@ -494,11 +494,11 @@ def evaluate(
# we only want to operate on groups here.
continue
group_config = group_info["config"] if "config" in group_info else {}
aggregate_metric = group_config["aggregate_metric"] if "aggregate_metric" in group_config else False
show_group_table = show_group_table | aggregate_metric
weight_by_size = group_config["weight_by_size"] if "weight_by_size" in group_config else False
if aggregate_metric is False:
group_config = lm_eval.api.task.GroupConfig(
**group_info["config"] if "config" in group_info else {}
)
show_group_table = show_group_table | group_config["aggregate_metric"]
if group_config["aggregate_metric"] is False:
results[group][" "] = " "
continue
......@@ -533,7 +533,10 @@ def evaluate(
# compute group's pooled metric and stderr
results[group][
metric
] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
] = lm_eval.api.metrics.aggregate_subtask_metrics(
metrics,
sizes if group_config["weight_by_size"] else [1] * len(sizes),
)
# TODO: calculate grouped metric using aggregation fn
if "N/A" in stderrs:
results[group][stderr] = "N/A"
......@@ -573,7 +576,11 @@ def evaluate(
results_dict = {
"results": dict(results_agg.items()),
**({"groups": dict(groups_agg.items())} if (bool(groups_agg) & show_group_table) else {}),
**(
{"groups": dict(groups_agg.items())}
if (bool(groups_agg) & show_group_table)
else {}
),
"group_subtasks": dict(reversed(task_hierarchy.items())),
"configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment