Commit 1dc8f96f authored by lintangsutawika

default to weighted averaging

parent 92f25463
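
This commit switches group-level metric aggregation from collecting per-task score lists (averaged with np.average at the end) to an incremental, sample-count-weighted running mean: each task now reports its score together with its size (len(items)), and the group score is updated in place as tasks are folded in. Below is a minimal sketch of that update rule, with illustrative names (scores, sizes) that are not part of the commit, showing the incremental form agrees with a plain weighted average:

    import numpy as np

    scores = [0.50, 0.80, 0.65]  # per-task metric scores within one group (made up)
    sizes = [100, 50, 25]        # per-task sample counts, the "size" entries (made up)

    # Incremental form used in the diff: fold each task into a running mean.
    group_score = 0.0
    total_size = 0
    for score, size in zip(scores, sizes):
        # First iteration reduces to group_score = score, since total_size == 0.
        group_score = (group_score * total_size + score * size) / (total_size + size)
        total_size += size

    # Closed form: a sample-count-weighted average of the task scores.
    assert np.isclose(group_score, np.average(scores, weights=sizes))
    print(group_score)  # ≈ 0.6071

In the diff itself the first task of a group takes the else branch (the metric key is not yet in results[group]), which is the same base case as starting the running mean at the first score.
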
@@ -219,6 +219,7 @@ def evaluate(
     padding_requests = collections.defaultdict(int)
     # store the hierarchy to do proper ordering
     task_hierarchy = collections.defaultdict(list)
+    group_hierarchy = collections.defaultdict(list)
     # store the ordering of tasks and groups
     task_order = collections.defaultdict(int)
     # store the aggregation for aggregating across tasks in the same group
@@ -450,22 +451,26 @@ def evaluate(
             agg_fn = task.aggregation()[metric]
             task_score = agg_fn(items)
-            if group_name is not None:
-                sample_metric_key = metric + "(sample agg)," + key
-                for grouping in task_to_group[task_name]:
-                    if metric_key in results[grouping]:
-                        results[grouping][metric_key].append(task_score)
-                    else:
-                        results[grouping][metric_key] = [task_score]
-                    if sample_metric_key in results[grouping]:
-                        results[grouping][sample_metric_key] += items
-                    else:
-                        results[grouping][sample_metric_key] = items.copy()
-                    sample_agg_fn[grouping][sample_metric_key] = agg_fn
+            task_size = len(items)
+            # if group_name is not None:
+            #     sample_metric_key = metric + "(sample agg)," + key
+            #     for grouping in task_to_group[task_name]:
+            #         if metric_key in results[grouping]:
+            #             results[grouping][metric_key].append(task_score)
+            #             results[grouping]["size"].append(task_size)
+            #         else:
+            #             results[grouping][metric_key] = [task_score]
+            #             results[grouping]["size"] = [task_size]
+            #         if sample_metric_key in results[grouping]:
+            #             results[grouping][sample_metric_key] += items
+            #         else:
+            #             results[grouping][sample_metric_key] = items.copy()
+            #         sample_agg_fn[grouping][sample_metric_key] = agg_fn
             results[task_name][metric_key] = task_score
+            results[task_name]["size"] = task_size
             # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
             # so we run them less iterations. still looking for a cleaner way to do this
@@ -481,18 +486,36 @@ def evaluate(
                 results[task_name][metric + "_stderr" + "," + key] = stderr(items)
     if bool(results):
-        for task_or_group in results.keys():
-            for metric in results[task_or_group].keys():
-                if type(results[task_or_group][metric]) == list:
-                    if "(sample agg)" in metric:
-                        results[task_or_group][metric] = sample_agg_fn[
-                            task_or_group
-                        ][metric](results[task_or_group][metric])
-                    else:
-                        results[task_or_group][metric] = np.average(
-                            results[task_or_group][metric]
-                        )
-                    versions[task_or_group] = "N/A"
+        for group, task_list in reversed(task_hierarchy.items()):
+            versions[group] = "N/A"
+            task_score_dict = {}
+            total_size = 0
+            for task in task_list:
+                metrics = results[task]
+                if "size" in metrics:
+                    current_size = metrics.pop("size")
+                else:
+                    current_size = 1
+                for metric in [key for key in metrics.keys()]:
+                    if "_stderr" in metric:
+                        print(metric)
+                    metric_score = results[task][metric]
+                    if metric in results[group]:
+                        results[group][metric] = (results[group][metric]*total_size + metric_score*current_size)/(total_size+current_size)
+                    else:
+                        results[group][metric] = metric_score
+                # Different formula for agg stderr
+                total_size += current_size
     for task_name, task in task_dict.items():
         if type(task) == tuple:
......
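
Two details of the new aggregation are worth noting. Groups are folded in via reversed(task_hierarchy.items()), presumably so that nested groups are aggregated before any parent group consumes their scores, and entries without a recorded "size" fall back to a weight of 1. The "# Different formula for agg stderr" comment flags that standard errors cannot be combined with the same running mean as point estimates; the commit leaves that aggregation open, and the print(metric) on _stderr keys looks like a debugging placeholder for it.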