Commit d2804132 authored by lintangsutawika

aggregates across tasks within the same group, in addition to the average across per-task averages

parent d4f62844
@@ -211,15 +211,19 @@ def evaluate(
     samples = collections.defaultdict(list)
     # tracks all Instances/requests a model must generate output on.
     requests = collections.defaultdict(list)
-    # Stores task scores based on task grouping.
+    # Aggregated task scores presented with groups
     results_agg = collections.defaultdict(dict)
+    # Aggregated groups scores only
+    groups_agg = collections.defaultdict(dict)
     # tracks if a task was chosen via user selecting a group containing it
     # stores the amount to pad out reqs per req. type so that
     # number of fwd passes per distributed rank is equal
     padding_requests = collections.defaultdict(int)
     task_hierarchy = collections.defaultdict(list)
+    task_order = collections.defaultdict(int)
+    sample_agg_fn = collections.defaultdict(dict)

     # get lists of each type of request
     for task_name, task in task_dict.items():
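This hunk introduces the bookkeeping the rest of the commit relies on. As a rough sketch of the shapes these containers are expected to hold (the group and metric names below are hypothetical, not taken from the diff):

import collections

# group name -> list of member tasks (or nested subgroups)
task_hierarchy = collections.defaultdict(list)
task_hierarchy["mmlu"].append("mmlu_stem")  # defaultdict: no key-existence check needed

# task or group name -> nesting depth, used later for display ordering
task_order = collections.defaultdict(int)

# group name -> {metric key -> aggregation fn applied to pooled raw samples}
sample_agg_fn = collections.defaultdict(dict)
sample_agg_fn["mmlu"]["acc,none"] = lambda items: sum(items) / len(items)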
@@ -405,6 +409,35 @@ def evaluate(
         vals = vals_torch

     if lm.rank == 0:
+        ### Get task ordering for correct sample-wide aggregation
+        group_to_task = {}
+        for group in task_hierarchy.keys():
+            if group not in task_order:
+                task_order[group] = 0
+
+            if len(task_hierarchy[group]) > 0:
+                group_to_task[group] = task_hierarchy[group].copy()
+
+            for task in task_hierarchy[group]:
+                if task in task_order:
+                    task_order[task] += 1
+                else:
+                    task_order[task] = 1 + task_order[group]
+
+                if task in task_hierarchy:
+                    group_to_task[group].remove(task)
+                    group_to_task[group].extend(task_hierarchy[task])
+
+        task_to_group = {}
+        for group in group_to_task:
+            for task in group_to_task[group]:
+                if task in task_to_group:
+                    task_to_group[task].append(group)
+                else:
+                    task_to_group[task] = [group]
+
         ### Aggregate results over all datapoints ###
         # aggregate results ; run bootstrap CIs
         for (task_name, key, metric), items in vals.items():
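The added block flattens one level of group nesting (a subgroup listed inside a group is replaced by that subgroup's own members) and then inverts the mapping so every task knows each group it reports into. A self-contained sketch of the same logic on a toy two-level hierarchy, with hypothetical task names:

from collections import defaultdict

task_hierarchy = {
    "mmlu": ["mmlu_stem"],
    "mmlu_stem": ["abstract_algebra", "astronomy"],
}
task_order = defaultdict(int)

group_to_task = {}
for group in task_hierarchy.keys():
    if group not in task_order:
        task_order[group] = 0
    if len(task_hierarchy[group]) > 0:
        group_to_task[group] = task_hierarchy[group].copy()
    for task in task_hierarchy[group]:
        if task in task_order:
            task_order[task] += 1
        else:
            task_order[task] = 1 + task_order[group]
        if task in task_hierarchy:
            # task is itself a group: swap it for its children
            group_to_task[group].remove(task)
            group_to_task[group].extend(task_hierarchy[task])

task_to_group = {}
for group in group_to_task:
    for task in group_to_task[group]:
        task_to_group.setdefault(task, []).append(group)

print(group_to_task)
# {'mmlu': ['abstract_algebra', 'astronomy'],
#  'mmlu_stem': ['abstract_algebra', 'astronomy']}
print(task_to_group)
# {'abstract_algebra': ['mmlu', 'mmlu_stem'], 'astronomy': ['mmlu', 'mmlu_stem']}

Note the flattening is single-pass: a hierarchy nested more than one level deep would only be expanded one step per group entry.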
@@ -416,17 +449,22 @@ def evaluate(
             else:
                 group_name = None

-            task_score = task.aggregation()[metric](items)
+            agg_fn = task.aggregation()[metric]
+            task_score = agg_fn(items)

             if group_name is not None:
-                sample_metric_key = metric + "(sample avg)," + key
-                task_metric_key = metric + "(task avg)," + key
-                if task_metric_key in results[group_name]:
-                    results[group_name][task_metric_key].append(task_score)
-                    results[group_name][sample_metric_key].extend(items)
-                else:
-                    results[group_name][task_metric_key] = [task_score]
-                    results[group_name][sample_metric_key] = items
+                sample_metric_key = metric + "(sample agg)," + key
+                for grouping in task_to_group[task_name]:
+                    if metric_key in results[grouping]:
+                        results[grouping][metric_key].append(task_score)
+                    else:
+                        results[grouping][metric_key] = [task_score]
+
+                    if sample_metric_key in results[grouping]:
+                        results[grouping][sample_metric_key] += items
+                    else:
+                        results[grouping][sample_metric_key] = items.copy()
+                    sample_agg_fn[grouping][sample_metric_key] = agg_fn

             results[task_name][metric_key] = task_score
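The rewritten block keeps two views of each metric per group: metric_key accumulates one score per member task, while sample_metric_key pools the raw per-sample values (and remembers the task's aggregation function in sample_agg_fn) so the group score can later be recomputed over all samples at once. A small numeric sketch of why the two views disagree, with made-up accuracies:

import numpy as np

task_a = [1, 1, 1, 1]  # 4 samples, per-task accuracy 1.0
task_b = [0, 0]        # 2 samples, per-task accuracy 0.0

# task average: each task weighs the same regardless of size
task_avg = np.average([np.average(task_a), np.average(task_b)])  # 0.5

# sample aggregation: the agg fn runs over the pooled samples,
# so larger tasks contribute proportionally more
sample_agg = np.average(task_a + task_b)  # 0.666...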
@@ -446,23 +484,13 @@ def evaluate(
         if bool(results):
             for task_or_group in results.keys():
                 for metric in results[task_or_group].keys():
-                    try:
-                        print(task_or_group, metric, len(results[task_or_group][metric]))
-                    except:
-                        pass
                     if type(results[task_or_group][metric]) == list:
-                        results[task_or_group][metric] = np.average(results[task_or_group][metric])
+                        if "(sample agg)" in metric:
+                            results[task_or_group][metric] = sample_agg_fn[task_or_group][metric](results[task_or_group][metric])
+                        else:
+                            results[task_or_group][metric] = np.average(results[task_or_group][metric])
                     versions[task_or_group] = "N/A"

-        for group in task_hierarchy.keys():
-            if group not in task_order:
-                task_order[group] = 0
-            for task in task_hierarchy[group]:
-                if task in task_order:
-                    task_order[task] += 1
-                else:
-                    task_order[task] = 1 + task_order[group]
-
         for task_name, task in task_dict.items():
             if type(task) == tuple:
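At this point group entries in results still hold lists; the final pass above collapses them to scalars, dispatching "(sample agg)" keys to the stored aggregation function and everything else to a plain mean over per-task scores. A runnable sketch of that reduction with hypothetical data:

import numpy as np

results = {"mmlu": {
    "acc,none": [0.75, 0.40],                 # one score per member task
    "acc(sample agg),none": [1, 1, 1, 0, 0],  # pooled raw samples
}}
sample_agg_fn = {"mmlu": {"acc(sample agg),none": np.average}}

for task_or_group in results:
    for metric, value in results[task_or_group].items():
        if isinstance(value, list):
            if "(sample agg)" in metric:
                results[task_or_group][metric] = sample_agg_fn[task_or_group][metric](value)
            else:
                results[task_or_group][metric] = np.average(value)

print(results)
# {'mmlu': {'acc,none': 0.575, 'acc(sample agg),none': 0.6}}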