Commit 18a7c8b1 authored by lintangsutawika's avatar lintangsutawika
Browse files

attach group identifier to task for aggregation

parent ef7588b6
...@@ -191,10 +191,21 @@ def evaluate( ...@@ -191,10 +191,21 @@ def evaluate(
samples = collections.defaultdict(list) samples = collections.defaultdict(list)
requests = collections.defaultdict(list) requests = collections.defaultdict(list)
aggregate = collections.defaultdict(dict) aggregate = collections.defaultdict(dict)
task_groups = collections.defaultdict(dict)
padding_requests = collections.defaultdict(int) padding_requests = collections.defaultdict(int)
# get lists of each type of request # get lists of each type of request
for task_name, task in task_dict.items(): for task_name, task in task_dict.items():
if type(task) == tuple:
group, task = task
# if group in task_groups:
# task_groups[group].append(task_name)
# else:
# task_groups[group] = [task_name]
task_groups[task_name] = group
versions[task_name] = task.VERSION versions[task_name] = task.VERSION
configs[task_name] = dict(task.dump_config()) configs[task_name] = dict(task.dump_config())
...@@ -269,6 +280,8 @@ def evaluate( ...@@ -269,6 +280,8 @@ def evaluate(
### Postprocess outputs ### ### Postprocess outputs ###
# TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately) # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
for task_name, task in task_dict.items(): for task_name, task in task_dict.items():
if type(task) == tuple:
group, task = task
task.apply_filters() task.apply_filters()
### Collect values of metrics on all datapoints ### ### Collect values of metrics on all datapoints ###
...@@ -276,6 +289,8 @@ def evaluate( ...@@ -276,6 +289,8 @@ def evaluate(
# unpack results and sort back in order and return control to Task # unpack results and sort back in order and return control to Task
for task_name, task in task_dict.items(): for task_name, task in task_dict.items():
if type(task) == tuple:
group, task = task
# TODO: make it possible to use a different metric per filter # TODO: make it possible to use a different metric per filter
# iterate over different filters used # iterate over different filters used
for key in task.instances[0].filtered_resps.keys(): for key in task.instances[0].filtered_resps.keys():
...@@ -361,6 +376,8 @@ def evaluate( ...@@ -361,6 +376,8 @@ def evaluate(
# aggregate results ; run bootstrap CIs # aggregate results ; run bootstrap CIs
for (task_name, key, metric), items in vals.items(): for (task_name, key, metric), items in vals.items():
task = task_dict[task_name] task = task_dict[task_name]
if type(task) == tuple:
group, task = task
task_score = task.aggregation()[metric](items) task_score = task.aggregation()[metric](items)
results[task_name][metric + "," + key] = task_score results[task_name][metric + "," + key] = task_score
...@@ -373,10 +390,11 @@ def evaluate( ...@@ -373,10 +390,11 @@ def evaluate(
# | word_perplexity # | word_perplexity
# | byte_perplexity # | byte_perplexity
# | bits_per_byte # | bits_per_byte
if metric not in aggregate: group_name = task_groups[task_name]
aggregate[metric] = [task_score] if metric not in aggregate[group_name]:
aggregate[group_name][metric] = [task_score]
else: else:
aggregate[metric].append(task_score) aggregate[group_name][metric].append(task_score)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this # so we run them less iterations. still looking for a cleaner way to do this
...@@ -391,9 +409,10 @@ def evaluate( ...@@ -391,9 +409,10 @@ def evaluate(
if stderr is not None: if stderr is not None:
results[task_name][metric + "_stderr" + "," + key] = stderr(items) results[task_name][metric + "_stderr" + "," + key] = stderr(items)
for metric in aggregate.keys(): for group in aggregate.keys():
results["Aggregate"][metric] = np.average(aggregate[metric]) for metric in aggregate[group].keys():
versions["Aggregate"] = "N/A" aggregate[group][metric] = np.average(aggregate[group][metric])
versions[group] = "N/A"
results_dict = { results_dict = {
"results": dict(results), "results": dict(results),
......
...@@ -128,11 +128,15 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs): ...@@ -128,11 +128,15 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
if isinstance(task_element, str): if isinstance(task_element, str):
if task_element in GROUP_REGISTRY: if task_element in GROUP_REGISTRY:
group_name = task_element
for task_name in GROUP_REGISTRY[task_element]: for task_name in GROUP_REGISTRY[task_element]:
if task_name not in task_name_from_registry_dict: if task_name not in task_name_from_registry_dict:
task_name_from_registry_dict = { task_name_from_registry_dict = {
**task_name_from_registry_dict, **task_name_from_registry_dict,
task_name: get_task(task_name=task_name, config=config), task_name: (
group_name,
get_task(task_name=task_name, config=config),
),
} }
else: else:
task_name = task_element task_name = task_element
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment