Commit d2804132 authored by lintangsutawika

Aggregate across samples of tasks within the same group, in addition to the average across per-task averages

parent d4f62844
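
For context, a minimal sketch of the two aggregation modes this commit distinguishes; the task names and scores below are hypothetical:

    import numpy as np

    # Hypothetical per-sample correctness scores for two tasks in one group.
    task_scores = {"task_a": [1, 0, 1, 1], "task_b": [0, 0]}

    # Task average: mean of per-task means; every task weighs equally.
    task_avg = np.mean([np.mean(v) for v in task_scores.values()])  # 0.375

    # Sample aggregation: aggregate over the pooled samples; larger tasks weigh more.
    pooled = [s for v in task_scores.values() for s in v]
    sample_agg = np.mean(pooled)  # 0.5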
@@ -211,15 +211,19 @@ def evaluate(
     samples = collections.defaultdict(list)
     # tracks all Instances/requests a model must generate output on.
     requests = collections.defaultdict(list)
-    # Stores task scores based on task grouping.
+    # Aggregated task scores presented with groups
     results_agg = collections.defaultdict(dict)
+    # Aggregated groups scores only
     groups_agg = collections.defaultdict(dict)
+    # tracks if a task was chosen via user selecting a group containing it
     # stores the amount to pad out reqs per req. type so that
     # number of fwd passes per distributed rank is equal
     padding_requests = collections.defaultdict(int)
     task_hierarchy = collections.defaultdict(list)
     task_order = collections.defaultdict(int)
+    sample_agg_fn = collections.defaultdict(dict)
     # get lists of each type of request
     for task_name, task in task_dict.items():
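
The new `sample_agg_fn` dict records, per group, which aggregation function to apply to the pooled per-sample scores. A hedged sketch of the shape it ends up with (the group and metric names are hypothetical):

    import collections
    import numpy as np

    # Keyed as sample_agg_fn[group][metric + "(sample agg)," + key] = fn
    sample_agg_fn = collections.defaultdict(dict)
    sample_agg_fn["mmlu"]["acc(sample agg),none"] = np.mean

    # The stored fn later aggregates the group's pooled per-sample scores.
    print(sample_agg_fn["mmlu"]["acc(sample agg),none"]([1, 0, 1]))  # ~0.667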
@@ -405,6 +409,35 @@ def evaluate(
         vals = vals_torch

     if lm.rank == 0:
+        ### Get task ordering for correct sample-wide aggregation
+        group_to_task = {}
+        for group in task_hierarchy.keys():
+            if group not in task_order:
+                task_order[group] = 0
+
+            if len(task_hierarchy[group]) > 0:
+                group_to_task[group] = task_hierarchy[group].copy()
+
+                for task in task_hierarchy[group]:
+                    if task in task_order:
+                        task_order[task] += 1
+                    else:
+                        task_order[task] = 1 + task_order[group]
+
+                    if task in task_hierarchy:
+                        group_to_task[group].remove(task)
+                        group_to_task[group].extend(task_hierarchy[task])
+
+        task_to_group = {}
+        for group in group_to_task:
+            for task in group_to_task[group]:
+                if task in task_to_group:
+                    task_to_group[task].append(group)
+                else:
+                    task_to_group[task] = [group]
+
         ### Aggregate results over all datapoints ###
         # aggregate results ; run bootstrap CIs
         for (task_name, key, metric), items in vals.items():
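
A worked example of what this flattening pass produces, assuming a hypothetical two-level hierarchy:

    # Group "knowledge" contains subgroup "mmlu", which holds two leaf tasks.
    task_hierarchy = {
        "knowledge": ["mmlu"],
        "mmlu": ["mmlu_anatomy", "mmlu_law"],
    }

    # group_to_task replaces each subgroup with its leaf tasks:
    #   {"knowledge": ["mmlu_anatomy", "mmlu_law"],
    #    "mmlu":      ["mmlu_anatomy", "mmlu_law"]}
    # task_to_group inverts that mapping, one entry per leaf:
    #   {"mmlu_anatomy": ["knowledge", "mmlu"],
    #    "mmlu_law":     ["knowledge", "mmlu"]}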
@@ -416,17 +449,22 @@ def evaluate(
             else:
                 group_name = None

-            task_score = task.aggregation()[metric](items)
+            agg_fn = task.aggregation()[metric]
+            task_score = agg_fn(items)

             if group_name is not None:
-                sample_metric_key = metric + "(sample avg)," + key
-                task_metric_key = metric + "(task avg)," + key
-                if task_metric_key in results[group_name]:
-                    results[group_name][task_metric_key].append(task_score)
-                    results[group_name][sample_metric_key].extend(items)
-                else:
-                    results[group_name][task_metric_key] = [task_score]
-                    results[group_name][sample_metric_key] = items
+                sample_metric_key = metric + "(sample agg)," + key
+                for grouping in task_to_group[task_name]:
+                    if metric_key in results[grouping]:
+                        results[grouping][metric_key].append(task_score)
+                    else:
+                        results[grouping][metric_key] = [task_score]
+
+                    if sample_metric_key in results[grouping]:
+                        results[grouping][sample_metric_key] += items
+                    else:
+                        results[grouping][sample_metric_key] = items.copy()
+                    sample_agg_fn[grouping][sample_metric_key] = agg_fn

             results[task_name][metric_key] = task_score
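
After this loop, each group row in `results` carries two parallel lists per metric: one task-level score per member task, and the pooled per-sample scores. A hedged sketch with hypothetical values:

    results = {
        "mmlu": {
            # one aggregated score per member task; later averaged (task avg)
            "acc,none": [0.75, 0.0],
            # per-sample scores pooled across member tasks; later re-aggregated
            "acc(sample agg),none": [1, 0, 1, 1, 0, 0],
        }
    }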
@@ -446,23 +484,13 @@ def evaluate(

         if bool(results):
             for task_or_group in results.keys():
                 for metric in results[task_or_group].keys():
-                    try:
-                        print(task_or_group, metric, len(results[task_or_group][metric]))
-                    except:
-                        pass
                     if type(results[task_or_group][metric]) == list:
-                        results[task_or_group][metric] = np.average(results[task_or_group][metric])
+                        if "(sample agg)" in metric:
+                            results[task_or_group][metric] = sample_agg_fn[task_or_group][metric](results[task_or_group][metric])
+                        else:
+                            results[task_or_group][metric] = np.average(results[task_or_group][metric])
                 versions[task_or_group] = "N/A"

-        for group in task_hierarchy.keys():
-            if group not in task_order:
-                task_order[group] = 0
-            for task in task_hierarchy[group]:
-                if task in task_order:
-                    task_order[task] += 1
-                else:
-                    task_order[task] = 1 + task_order[group]
-
         for task_name, task in task_dict.items():
             if type(task) == tuple:
...
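
An end-to-end sketch of the final aggregation pass above, reusing the hypothetical values from earlier:

    import numpy as np

    group_results = {
        "acc,none": [0.75, 0.0],                     # per-task scores
        "acc(sample agg),none": [1, 0, 1, 1, 0, 0],  # pooled samples
    }
    sample_agg_fn = {"acc(sample agg),none": np.mean}

    for metric in group_results:
        if "(sample agg)" in metric:
            # sample-level metrics reuse the task's own aggregation fn
            group_results[metric] = sample_agg_fn[metric](group_results[metric])  # 0.5
        else:
            # task-level metrics fall back to a plain mean of task scores
            group_results[metric] = np.average(group_results[metric])            # 0.375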