Commit 62e4ea10 authored by lintangsutawika
Browse files

simplified evaluator by removing unused variables

parent f8c2cfcb
...@@ -234,8 +234,7 @@ def evaluate( ...@@ -234,8 +234,7 @@ def evaluate(
padding_requests = collections.defaultdict(int) padding_requests = collections.defaultdict(int)
# store the hierarchy to do proper ordering # store the hierarchy to do proper ordering
task_hierarchy = collections.defaultdict(list) task_hierarchy = collections.defaultdict(list)
# store the ordering of tasks and groups # store task aliases
task_order = collections.defaultdict(int)
task_group_alias = collections.defaultdict(dict) task_group_alias = collections.defaultdict(dict)
# store num-fewshot value per task # store num-fewshot value per task
num_fewshot = collections.defaultdict(int) num_fewshot = collections.defaultdict(int)
...@@ -440,32 +439,6 @@ def evaluate( ...@@ -440,32 +439,6 @@ def evaluate(
vals = vals_torch vals = vals_torch
if lm.rank == 0: if lm.rank == 0:
### Get task ordering for correct sample-wide aggregation
group_to_task = {}
for group in task_hierarchy.keys():
if group not in task_order:
task_order[group] = 0
if len(task_hierarchy[group]) > 0:
group_to_task[group] = task_hierarchy[group].copy()
for task in task_hierarchy[group]:
if task in task_order:
task_order[task] += 1
else:
task_order[task] = 1 + task_order[group]
if task in task_hierarchy:
group_to_task[group].remove(task)
group_to_task[group].extend(task_hierarchy[task])
task_to_group = {}
for group in group_to_task:
for task in group_to_task[group]:
if task in task_to_group:
task_to_group[task].append(group)
else:
task_to_group[task] = [group]
### Aggregate results over all datapoints ### ### Aggregate results over all datapoints ###
# aggregate results ; run bootstrap CIs # aggregate results ; run bootstrap CIs
...@@ -494,6 +467,8 @@ def evaluate( ...@@ -494,6 +467,8 @@ def evaluate(
if stderr is not None: if stderr is not None:
results[task_name][metric + "_stderr" + "," + key] = stderr(items) results[task_name][metric + "_stderr" + "," + key] = stderr(items)
else:
results[task_name][metric + "_stderr" + "," + key] = 0
if bool(results): if bool(results):
for group, task_list in reversed(task_hierarchy.items()): for group, task_list in reversed(task_hierarchy.items()):
...@@ -551,37 +526,36 @@ def evaluate( ...@@ -551,37 +526,36 @@ def evaluate(
results[group]["samples"] = total_size results[group]["samples"] = total_size
def print_tasks(task_hierarchy, task_order, task_version, task_group_alias): def print_tasks(task_hierarchy, tab=0):
results_agg = collections.defaultdict(dict) results_agg = collections.defaultdict(dict)
groups_agg = collections.defaultdict(dict) groups_agg = collections.defaultdict(dict)
for group_name, task_list in task_hierarchy.items():
order = task_order[group_name]
results_agg[group_name] = results[group_name].copy()
results_agg[group_name]["tab"] = order
if (order < max(task_order.values())) and (len(task_list) > 0): (group_name, task_list), *_ = task_hierarchy.items()
groups_agg[group_name] = results[group_name].copy() task_list = sorted(task_list)
groups_agg[group_name]["tab"] = order
if task_list != []: results_agg[group_name] = results[group_name].copy()
for task in sorted(task_list): results_agg[group_name]["tab"] = tab
if task in task_hierarchy:
_task_hierarchy = {task: task_hierarchy[task]}
else:
_task_hierarchy = {task: []}
_results_agg, _groups_agg, task_version = print_tasks( if len(task_list) > 0:
_task_hierarchy, task_order, task_version, task_group_alias groups_agg[group_name] = results[group_name].copy()
) groups_agg[group_name]["tab"] = tab
results_agg = {**results_agg, **_results_agg} for task_name in task_list:
groups_agg = {**groups_agg, **_groups_agg} if task_name in task_hierarchy:
_task_hierarchy = {
**{task_name: task_hierarchy[task_name]},
**task_hierarchy,
}
else:
_task_hierarchy = {task_name: []}
return results_agg, groups_agg, task_version _results_agg, _groups_agg = print_tasks(_task_hierarchy, tab + 1)
results_agg = {**results_agg, **_results_agg}
groups_agg = {**groups_agg, **_groups_agg}
results_agg, groups_agg, versions = print_tasks( return results_agg, groups_agg
task_hierarchy, task_order, versions, task_group_alias
) results_agg, groups_agg = print_tasks(task_hierarchy)
for task in results_agg: for task in results_agg:
task_results = results_agg[task] task_results = results_agg[task]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment