"docker/Dockerfile" did not exist on "478602ba59c0bfe7ab9a094b9f1b7b33cfeecba4"
Unverified Commit 42730d90 authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Simplify evaluator (#1126)

* save progress

* fixed issue with table only showing 1 group

* store aliases directly in results_agg

* removed unused parts
parent 8e87eff4
......@@ -234,9 +234,6 @@ def evaluate(
padding_requests = collections.defaultdict(int)
# store the hierarchy to do proper ordering
task_hierarchy = collections.defaultdict(list)
# store the ordering of tasks and groups
task_order = collections.defaultdict(int)
task_group_alias = collections.defaultdict(dict)
# store num-fewshot value per task
num_fewshot = collections.defaultdict(int)
......@@ -264,14 +261,14 @@ def evaluate(
num_fewshot[task_name] = n_shot
if "task_alias" in configs[task_name]:
task_group_alias[task_name] = configs[task_name]["task_alias"]
results[task_name]["alias"] = configs[task_name]["task_alias"]
if (
("group_alias" in configs[task_name])
and (group_name not in task_group_alias)
and (group_name not in results)
and (group_name is not None)
):
task_group_alias[group_name] = configs[task_name]["group_alias"]
results[group_name]["alias"] = configs[task_name]["group_alias"]
if limit is not None:
if task.has_test_docs():
......@@ -440,32 +437,6 @@ def evaluate(
vals = vals_torch
if lm.rank == 0:
### Get task ordering for correct sample-wide aggregation
group_to_task = {}
for group in task_hierarchy.keys():
if group not in task_order:
task_order[group] = 0
if len(task_hierarchy[group]) > 0:
group_to_task[group] = task_hierarchy[group].copy()
for task in task_hierarchy[group]:
if task in task_order:
task_order[task] += 1
else:
task_order[task] = 1 + task_order[group]
if task in task_hierarchy:
group_to_task[group].remove(task)
group_to_task[group].extend(task_hierarchy[task])
task_to_group = {}
for group in group_to_task:
for task in group_to_task[group]:
if task in task_to_group:
task_to_group[task].append(group)
else:
task_to_group[task] = [group]
### Aggregate results over all datapoints ###
# aggregate results ; run bootstrap CIs
......@@ -505,7 +476,10 @@ def evaluate(
total_size = 0
for task in task_list:
metrics = results[task]
metrics = results[task].copy()
if "alias" in metrics:
metrics.pop("alias")
current_size = metrics.pop("samples")
# TODO: There should be a way for users
......@@ -553,71 +527,77 @@ def evaluate(
results[group]["samples"] = total_size
def print_tasks(task_hierarchy, task_order, task_version, task_group_alias):
def print_tasks(task_hierarchy, results, tab=0):
results_agg = collections.defaultdict(dict)
groups_agg = collections.defaultdict(dict)
for group_name, task_list in task_hierarchy.items():
order = task_order[group_name]
results_agg[group_name] = results[group_name].copy()
results_agg[group_name]["tab"] = order
if (order < max(task_order.values())) and (len(task_list) > 0):
groups_agg[group_name] = results[group_name].copy()
groups_agg[group_name]["tab"] = order
(group_name, task_list), *_ = task_hierarchy.items()
task_list = sorted(task_list)
if task_list != []:
for task in sorted(task_list):
if task in task_hierarchy:
_task_hierarchy = {task: task_hierarchy[task]}
else:
_task_hierarchy = {task: []}
_results_agg, _groups_agg, task_version = print_tasks(
_task_hierarchy, task_order, task_version, task_group_alias
)
results_agg = {**results_agg, **_results_agg}
groups_agg = {**groups_agg, **_groups_agg}
return results_agg, groups_agg, task_version
results_agg, groups_agg, versions = print_tasks(
task_hierarchy, task_order, versions, task_group_alias
)
results_agg[group_name] = results[group_name].copy()
# results_agg[group_name]["tab"] = tab
if "samples" in results_agg[group_name]:
results_agg[group_name].pop("samples")
for task in results_agg:
task_results = results_agg[task]
tab_string = " " * tab + "- " if tab > 0 else ""
if "samples" in task_results:
task_results.pop("samples")
tab_string = ""
if "tab" in task_results:
tab = task_results.pop("tab")
tab_string = " " * tab + "- " if tab > 0 else ""
if task in task_group_alias:
task_alias = task_group_alias[task]
results_agg[task]["alias"] = tab_string + task_alias
if "alias" in results_agg[group_name]:
results_agg[group_name]["alias"] = (
tab_string + results_agg[group_name]["alias"]
)
else:
results_agg[task]["alias"] = tab_string + task
for group in groups_agg:
group_results = groups_agg[group]
if "samples" in group_results:
group_results.pop("samples")
results_agg[group_name]["alias"] = tab_string + group_name
tab_string = ""
if "tab" in group_results:
tab = group_results.pop("tab")
tab_string = " " * tab + "- " if tab > 0 else ""
if len(task_list) > 0:
groups_agg[group_name] = results[group_name].copy()
# groups_agg[group_name]["tab"] = tab
if "samples" in groups_agg[group_name]:
groups_agg[group_name].pop("samples")
if group in task_group_alias:
group_alias = task_group_alias[group]
groups_agg[group]["alias"] = tab_string + group_alias
else:
groups_agg[group]["alias"] = tab_string + group
if "alias" in groups_agg[group_name]:
groups_agg[group_name]["alias"] = (
tab_string + groups_agg[group_name]["alias"]
)
else:
groups_agg[group_name]["alias"] = tab_string + group_name
for task_name in task_list:
if task_name in task_hierarchy:
_task_hierarchy = {
**{task_name: task_hierarchy[task_name]},
**task_hierarchy,
}
else:
_task_hierarchy = {
**{task_name: []},
**task_hierarchy,
}
_results_agg, _groups_agg = print_tasks(
_task_hierarchy, results, tab + 1
)
results_agg = {**results_agg, **_results_agg}
groups_agg = {**groups_agg, **_groups_agg}
return results_agg, groups_agg
results_agg = collections.defaultdict(dict)
groups_agg = collections.defaultdict(dict)
all_tasks_list = list(task_hierarchy.keys())
left_tasks_list = []
while True:
add_tasks_list = list(k for k in results_agg.keys())
left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
if len(left_tasks_list) == 0:
break
_task_hierarchy = {
k: v for k, v in task_hierarchy.items() if k in left_tasks_list
}
_results_agg, _groups_agg = print_tasks(_task_hierarchy, results)
results_agg = {**results_agg, **_results_agg}
groups_agg = {**groups_agg, **_groups_agg}
for group_name, task_list in task_hierarchy.items():
if task_list != []:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment