Commit 6da6d187 authored by lintangsutawika
Browse files

Add a group config that allows disabling the group-score table, and group aggregation in general

parent 2a2566e6
...@@ -10,6 +10,7 @@ import torch ...@@ -10,6 +10,7 @@ import torch
import lm_eval.api.metrics import lm_eval.api.metrics
import lm_eval.api.registry import lm_eval.api.registry
import lm_eval.api.task
import lm_eval.models import lm_eval.models
from lm_eval.caching.cache import delete_cache from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import ( from lm_eval.evaluator_utils import (
...@@ -211,7 +212,7 @@ def simple_evaluate( ...@@ -211,7 +212,7 @@ def simple_evaluate(
task_obj = task_dict[task_name] task_obj = task_dict[task_name]
if isinstance(task_obj, tuple): if isinstance(task_obj, tuple):
_, task_obj = task_obj _, task_obj = task_obj
if task_obj is None: if isinstance(task_obj, lm_eval.api.task.ConfigurableTask) is False:
continue continue
if task_obj.get_config("output_type") == "generate_until": if task_obj.get_config("output_type") == "generate_until":
...@@ -483,13 +484,24 @@ def evaluate( ...@@ -483,13 +484,24 @@ def evaluate(
### Calculate group metrics ### ### Calculate group metrics ###
if bool(results): if bool(results):
for group, task_list in reversed(task_hierarchy.items()): show_group_table = False
for group, group_info in reversed(task_hierarchy.items()):
task_list = group_info["tasks"]
if len(task_list) == 0: if len(task_list) == 0:
# task_hierarchy entries are either # task_hierarchy entries are either
# `group_name: [subtask1, subtask2, ...]` # `group_name: [subtask1, subtask2, ...]`
# or `task_name: []`. # or `task_name: []`.
# we only want to operate on groups here. # we only want to operate on groups here.
continue continue
group_config = group_info["config"] if "config" in group_info else {}
aggregate_metric = group_config["aggregate_metric"] if "aggregate_metric" in group_config else False
show_group_table = show_group_table | aggregate_metric
weight_by_size = group_config["weight_by_size"] if "weight_by_size" in group_config else False
if aggregate_metric is False:
results[group][" "] = " "
continue
metric_list = list( metric_list = list(
{ {
key key
...@@ -545,14 +557,15 @@ def evaluate( ...@@ -545,14 +557,15 @@ def evaluate(
break break
_task_hierarchy = { _task_hierarchy = {
k: v for k, v in task_hierarchy.items() if k in left_tasks_list k: v["tasks"] for k, v in task_hierarchy.items() if k in left_tasks_list
} }
_results_agg, _groups_agg = prepare_print_tasks(_task_hierarchy, results) _results_agg, _groups_agg = prepare_print_tasks(_task_hierarchy, results)
results_agg = {**results_agg, **_results_agg} results_agg = {**results_agg, **_results_agg}
groups_agg = {**groups_agg, **_groups_agg} groups_agg = {**groups_agg, **_groups_agg}
for group_name, task_list in task_hierarchy.items(): for group_name, group_info in task_hierarchy.items():
task_list = group_info["tasks"]
if task_list: if task_list:
num_fewshot[group_name] = num_fewshot[ num_fewshot[group_name] = num_fewshot[
task_list[0] task_list[0]
...@@ -560,7 +573,7 @@ def evaluate( ...@@ -560,7 +573,7 @@ def evaluate(
results_dict = { results_dict = {
"results": dict(results_agg.items()), "results": dict(results_agg.items()),
**({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}), **({"groups": dict(groups_agg.items())} if (bool(groups_agg) & show_group_table) else {}),
"group_subtasks": dict(reversed(task_hierarchy.items())), "group_subtasks": dict(reversed(task_hierarchy.items())),
"configs": dict(sorted(configs.items())), "configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())), "versions": dict(sorted(versions.items())),
......
...@@ -119,13 +119,26 @@ class TaskOutput: ...@@ -119,13 +119,26 @@ class TaskOutput:
def get_task_list(task_dict: dict) -> Tuple[Dict[str, list], List[TaskOutput]]: def get_task_list(task_dict: dict) -> Tuple[Dict[str, list], List[TaskOutput]]:
task_hierarchy = collections.defaultdict(list) # task_hierarchy = collections.defaultdict(list)
outputs = list(TaskOutput.from_taskdict(x, y) for x, y in task_dict.items()) task_hierarchy = collections.defaultdict(lambda: collections.defaultdict(list))
for task_output in outputs: outputs = []
for x, y in task_dict.items():
group, task_obj = y
if isinstance(task_obj, dict):
task_output = TaskOutput.from_taskdict(x, (group, None))
task_config = task_obj
else:
task_output = TaskOutput.from_taskdict(x, y)
task_config = task_obj.config
outputs.append(task_output)
if group_name := task_output.group_name: if group_name := task_output.group_name:
task_hierarchy[group_name].append(task_output.task_name) task_hierarchy[group_name]["tasks"].append(task_output.task_name)
if "group_config" in task_config:
task_hierarchy[group_name]["config"] = task_config["group_config"]
else: else:
task_hierarchy[task_output.task_name] = [] task_hierarchy[task_output.task_name]["tasks"] = []
# returns task_hierarchy tracking which groups contain which subtasks, # returns task_hierarchy tracking which groups contain which subtasks,
# and a list of TaskOutput classes for each non-group subtask # and a list of TaskOutput classes for each non-group subtask
return task_hierarchy, [x for x in outputs if x.task] return task_hierarchy, [x for x in outputs if x.task]
......
...@@ -240,7 +240,9 @@ class TaskManager: ...@@ -240,7 +240,9 @@ class TaskManager:
all_subtasks = {} all_subtasks = {}
if parent_name is not None: if parent_name is not None:
all_subtasks = {group_name: (parent_name, None)} # all_subtasks = {group_name: (parent_name, None)}
parent_group_config = self._get_config(parent_name)
all_subtasks = {group_name: (parent_name, parent_group_config)}
fn = partial( fn = partial(
self._load_individual_task_or_group, self._load_individual_task_or_group,
......
...@@ -4,3 +4,7 @@ task: ...@@ -4,3 +4,7 @@ task:
- mmlu_other - mmlu_other
- mmlu_social_sciences - mmlu_social_sciences
- mmlu_humanities - mmlu_humanities
group_config:
aggregate_metric: True
aggregate_fn: mean
weight_by_size: True
...@@ -252,14 +252,15 @@ def make_table(result_dict, column: str = "results"): ...@@ -252,14 +252,15 @@ def make_table(result_dict, column: str = "results"):
m, _, f = mf.partition(",") m, _, f = mf.partition(",")
if m.endswith("_stderr"): if m.endswith("_stderr"):
continue continue
if v != " ":
v = "%.4f" % v
if m + "_stderr" + "," + f in dic: if m + "_stderr" + "," + f in dic:
se = dic[m + "_stderr" + "," + f] se = dic[m + "_stderr" + "," + f]
if se != "N/A": if se != "N/A":
se = "%.4f" % se se = "%.4f" % se
values.append([k, version, f, n, m, "%.4f" % v, "±", se]) values.append([k, version, f, n, m, v, "±", se])
else: else:
values.append([k, version, f, n, m, "%.4f" % v, "", ""]) values.append([k, version, f, n, m, v, "", ""])
k = "" k = ""
version = "" version = ""
md_writer.value_matrix = values md_writer.value_matrix = values
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment