Commit 6da6d187 authored by lintangsutawika's avatar lintangsutawika
Browse files

add a group config that allows disabling table for group score and group aggregate in general

parent 2a2566e6
......@@ -10,6 +10,7 @@ import torch
import lm_eval.api.metrics
import lm_eval.api.registry
import lm_eval.api.task
import lm_eval.models
from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import (
......@@ -211,7 +212,7 @@ def simple_evaluate(
task_obj = task_dict[task_name]
if isinstance(task_obj, tuple):
_, task_obj = task_obj
if task_obj is None:
if isinstance(task_obj, lm_eval.api.task.ConfigurableTask) is False:
continue
if task_obj.get_config("output_type") == "generate_until":
......@@ -483,13 +484,24 @@ def evaluate(
### Calculate group metrics ###
if bool(results):
for group, task_list in reversed(task_hierarchy.items()):
show_group_table = False
for group, group_info in reversed(task_hierarchy.items()):
task_list = group_info["tasks"]
if len(task_list) == 0:
# task_hierarchy entries are either
# `group_name: [subtask1, subtask2, ...]`
# or `task_name: []`.
# we only want to operate on groups here.
continue
group_config = group_info["config"] if "config" in group_info else {}
aggregate_metric = group_config["aggregate_metric"] if "aggregate_metric" in group_config else False
show_group_table = show_group_table | aggregate_metric
weight_by_size = group_config["weight_by_size"] if "weight_by_size" in group_config else False
if aggregate_metric is False:
results[group][" "] = " "
continue
metric_list = list(
{
key
......@@ -545,14 +557,15 @@ def evaluate(
break
_task_hierarchy = {
k: v for k, v in task_hierarchy.items() if k in left_tasks_list
k: v["tasks"] for k, v in task_hierarchy.items() if k in left_tasks_list
}
_results_agg, _groups_agg = prepare_print_tasks(_task_hierarchy, results)
results_agg = {**results_agg, **_results_agg}
groups_agg = {**groups_agg, **_groups_agg}
for group_name, task_list in task_hierarchy.items():
for group_name, group_info in task_hierarchy.items():
task_list = group_info["tasks"]
if task_list:
num_fewshot[group_name] = num_fewshot[
task_list[0]
......@@ -560,7 +573,7 @@ def evaluate(
results_dict = {
"results": dict(results_agg.items()),
**({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
**({"groups": dict(groups_agg.items())} if (bool(groups_agg) & show_group_table) else {}),
"group_subtasks": dict(reversed(task_hierarchy.items())),
"configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())),
......
......@@ -119,13 +119,26 @@ class TaskOutput:
def get_task_list(task_dict: dict) -> Tuple[Dict[str, list], List[TaskOutput]]:
task_hierarchy = collections.defaultdict(list)
outputs = list(TaskOutput.from_taskdict(x, y) for x, y in task_dict.items())
for task_output in outputs:
# task_hierarchy = collections.defaultdict(list)
task_hierarchy = collections.defaultdict(lambda: collections.defaultdict(list))
outputs = []
for x, y in task_dict.items():
group, task_obj = y
if isinstance(task_obj, dict):
task_output = TaskOutput.from_taskdict(x, (group, None))
task_config = task_obj
else:
task_output = TaskOutput.from_taskdict(x, y)
task_config = task_obj.config
outputs.append(task_output)
if group_name := task_output.group_name:
task_hierarchy[group_name].append(task_output.task_name)
task_hierarchy[group_name]["tasks"].append(task_output.task_name)
if "group_config" in task_config:
task_hierarchy[group_name]["config"] = task_config["group_config"]
else:
task_hierarchy[task_output.task_name] = []
task_hierarchy[task_output.task_name]["tasks"] = []
# returns task_hierarchy tracking which groups contain which subtasks,
# and a list of TaskOutput classes for each non-group subtask
return task_hierarchy, [x for x in outputs if x.task]
......
......@@ -240,7 +240,9 @@ class TaskManager:
all_subtasks = {}
if parent_name is not None:
all_subtasks = {group_name: (parent_name, None)}
# all_subtasks = {group_name: (parent_name, None)}
parent_group_config = self._get_config(parent_name)
all_subtasks = {group_name: (parent_name, parent_group_config)}
fn = partial(
self._load_individual_task_or_group,
......
......@@ -4,3 +4,7 @@ task:
- mmlu_other
- mmlu_social_sciences
- mmlu_humanities
group_config:
aggregate_metric: True
aggregate_fn: mean
weight_by_size: True
......@@ -252,14 +252,15 @@ def make_table(result_dict, column: str = "results"):
m, _, f = mf.partition(",")
if m.endswith("_stderr"):
continue
if v != " ":
v = "%.4f" % v
if m + "_stderr" + "," + f in dic:
se = dic[m + "_stderr" + "," + f]
if se != "N/A":
se = "%.4f" % se
values.append([k, version, f, n, m, "%.4f" % v, "±", se])
values.append([k, version, f, n, m, v, "±", se])
else:
values.append([k, version, f, n, m, "%.4f" % v, "", ""])
values.append([k, version, f, n, m, v, "", ""])
k = ""
version = ""
md_writer.value_matrix = values
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment