Commit 6117c507 authored by lintangsutawika

changed how metrics are calculated

parent 028f04c7
@@ -370,7 +370,7 @@ def evaluate(
                 # subset instances to only this document id ; sort by idx
                 requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
                 requests.sort(key=lambda x: x.idx)
-                metrics = task.process_results(
+                items = task.process_results(
                     doc, [req.filtered_resps[key] for req in requests]
                 )
                 if log_samples:
@@ -383,10 +383,11 @@ def evaluate(
                         "resps": [req.resps for req in requests],
                         "filtered_resps": [req.filtered_resps[key] for req in requests],
                     }
-                    example.update(metrics)
+                    example.update(items)
                     samples[task_name].append(example)
-                for metric, value in metrics.items():
-                    vals[(task_name, key, metric)].append(value)
+                vals[(task_name, key)].append(items)
+                # for metric, value in results.items():
+                #     vals[(task_name, key, metric)].append(value)
 
     if lm.world_size > 1:
         # if multigpu, then gather data across all ranks
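For illustration, here is a minimal sketch of how the bookkeeping changes (the task name "arc_easy", the filter key "none", and the metric values are invented for this example, and per-document results are assumed to be dicts, as implied by example.update(items)): the old code stored one scalar per (task_name, key, metric) triple, while the new code appends the whole per-document result under (task_name, key).

    import collections

    # Hypothetical per-document results from task.process_results();
    # metric names and values are made up for illustration.
    per_doc_results = [{"acc": 1.0, "f1": 0.5}, {"acc": 0.0, "f1": 0.25}]

    # Old layout: one list of scalars per (task_name, key, metric).
    old_vals = collections.defaultdict(list)
    for item in per_doc_results:
        for metric, value in item.items():
            old_vals[("arc_easy", "none", metric)].append(value)
    # old_vals == {("arc_easy", "none", "acc"): [1.0, 0.0],
    #              ("arc_easy", "none", "f1"): [0.5, 0.25]}

    # New layout: one list of whole per-document items per (task_name, key).
    new_vals = collections.defaultdict(list)
    for item in per_doc_results:
        new_vals[("arc_easy", "none")].append(item)
    # new_vals == {("arc_easy", "none"): [{"acc": 1.0, "f1": 0.5},
    #                                     {"acc": 0.0, "f1": 0.25}]}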
@@ -399,7 +400,8 @@ def evaluate(
         # then collect metrics across all ranks
         vals_torch = collections.defaultdict(list)
-        for (task_name, key, metric), items in vals.items():
+        # for (task_name, key, metric), items in vals.items():
+        for (task_name, key), items in vals.items():
             numitem = 0
             if type(items[0]) == tuple:
                 numitem = len(items[0])
@@ -435,7 +437,8 @@ def evaluate(
             gathered_item = [tuple(g) for g in gathered_item]
             if lm.rank == 0:
-                vals_torch[(task_name, key, metric)] = gathered_item
+                # vals_torch[(task_name, key, metric)] = gathered_item
+                vals_torch[(task_name, key)] = gathered_item
 
     vals = vals_torch
@@ -469,18 +472,19 @@ def evaluate(
     ### Aggregate results over all datapoints ###
     # aggregate results ; run bootstrap CIs
-    for (task_name, key, metric), items in vals.items():
+    # for (task_name, key, metric), items in vals.items():
+    for (task_name, key), items in vals.items():
         task = task_dict[task_name]
-        metric_key = metric + "," + key
+        # metric_key = metric + "," + key
         if type(task) == tuple:
             group_name, task = task
         else:
             group_name = None
-        agg_fn = task.aggregation()[metric]
-        results[task_name][metric_key] = agg_fn(items)
+        for metric_key, metric_fn in task.aggregation().items():
+            results[task_name][metric_key] = metric_fn(*list(zip(*items)))
         results[task_name]["samples"] = len(items)
         # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
         # so we run them less iterations. still looking for a cleaner way to do this
...
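And a minimal sketch of what the new aggregation loop computes, assuming process_results returns a fixed-length tuple per document (the sample values and the stand-in aggregation functions below are invented for illustration): zip(*items) transposes the list of per-document tuples into per-position columns, and every function returned by task.aggregation() is called with all of those columns as positional arguments.

    # Hypothetical per-document result tuples, e.g. (exact_match, partial_credit).
    items = [(1.0, 0.5), (0.0, 0.25), (1.0, 1.0)]

    # Stand-ins for what task.aggregation() might return; the names are made up.
    def mean_of_first(first_col, second_col):
        return sum(first_col) / len(first_col)

    def mean_of_second(first_col, second_col):
        return sum(second_col) / len(second_col)

    aggregation = {"exact_match": mean_of_first, "partial_credit": mean_of_second}

    results = {}
    for metric_key, metric_fn in aggregation.items():
        # zip(*items) -> [(1.0, 0.0, 1.0), (0.5, 0.25, 1.0)]; each aggregation
        # function receives every transposed column as a positional argument.
        results[metric_key] = metric_fn(*list(zip(*items)))

    # results == {"exact_match": 0.666..., "partial_credit": 0.583...}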