Commit 2d96a8c8 authored by lintangsutawika
Browse files

add condition if --task is not a benchmark

parent ed304c1d
...@@ -398,6 +398,7 @@ def evaluate( ...@@ -398,6 +398,7 @@ def evaluate(
# | word_perplexity # | word_perplexity
# | byte_perplexity # | byte_perplexity
# | bits_per_byte # | bits_per_byte
if bool(task_groups):
group_name = task_groups[task_name] group_name = task_groups[task_name]
if metric not in aggregate[group_name]: if metric not in aggregate[group_name]:
aggregate[group_name][metric] = [task_score] aggregate[group_name][metric] = [task_score]
...@@ -417,6 +418,7 @@ def evaluate( ...@@ -417,6 +418,7 @@ def evaluate(
if stderr is not None: if stderr is not None:
results[task_name][metric + "_stderr" + "," + key] = stderr(items) results[task_name][metric + "_stderr" + "," + key] = stderr(items)
if not bool(aggregate):
for group in aggregate.keys(): for group in aggregate.keys():
for metric in aggregate[group].keys(): for metric in aggregate[group].keys():
aggregate[group][metric] = np.average(aggregate[group][metric]) aggregate[group][metric] = np.average(aggregate[group][metric])
...@@ -424,7 +426,7 @@ def evaluate( ...@@ -424,7 +426,7 @@ def evaluate(
results_dict = { results_dict = {
"results": dict(results), "results": dict(results),
"aggregate": dict(aggregate), **({"aggregate": dict(aggregate)} if bool(aggregate) else {}),
"configs": dict(configs), "configs": dict(configs),
"versions": dict(versions), "versions": dict(versions),
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment