"docs/license.md" did not exist on "61545bda3568b38fbf08f218a6e4091da83fc32e"
Commit 81495b0a authored by lintangsutawika
Browse files

num fewshot is printed in table

parent 8ffd2630
......@@ -217,6 +217,8 @@ def evaluate(
# store the ordering of tasks and groups
task_order = collections.defaultdict(int)
task_group_alias = collections.defaultdict(dict)
# store num-fewshot value per task
num_fewshot = collections.defaultdict(int)
# get lists of each type of request
for task_name, task in task_dict.items():
......@@ -234,6 +236,12 @@ def evaluate(
versions[task_name] = task.VERSION
configs[task_name] = dict(task.dump_config())
if "num_fewshot" in configs[task_name]:
n_shot = configs[task_name]["num_fewshot"]
else:
n_shot = -1
num_fewshot[task_name] = n_shot
if "task_alias" in configs[task_name]:
task_group_alias[task_name] = configs[task_name]["task_alias"]
......@@ -564,6 +572,7 @@ def evaluate(
_results_agg = collections.defaultdict(dict)
_versions = collections.defaultdict(dict)
_num_fewshot = collections.defaultdict(int)
for task in results_agg:
task_results = results_agg[task]
......@@ -579,11 +588,14 @@ def evaluate(
task_alias = task_group_alias[task]
_results_agg[tab_string + task_alias] = task_results
_versions[tab_string + task_alias] = versions[task]
_num_fewshot[tab_string + task_alias] = num_fewshot[task]
else:
_results_agg[tab_string + task] = task_results
_versions[tab_string + task] = versions[task]
_num_fewshot[tab_string + task] = num_fewshot[task]
results_agg = _results_agg
versions = _versions
num_fewshot = _num_fewshot
_groups_agg = collections.defaultdict(dict)
for group in groups_agg:
......@@ -609,6 +621,7 @@ def evaluate(
**({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
"configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())),
"n-shot": dict(sorted(num_fewshot.items())),
}
if log_samples:
results_dict["samples"] = dict(samples)
......
......@@ -286,6 +286,7 @@ def make_table(result_dict, column: str = "results"):
column_name,
"Version",
"Filter",
"n-shot",
"Metric",
"Value",
"",
......@@ -295,6 +296,7 @@ def make_table(result_dict, column: str = "results"):
column_name,
"Version",
"Filter",
"n-shot",
"Metric",
"Value",
"",
......@@ -305,6 +307,7 @@ def make_table(result_dict, column: str = "results"):
for k, dic in result_dict[column].items():
version = result_dict["versions"][k]
n = str(result_dict["n-shot"][k])
for (mf), v in dic.items():
m, _, f = mf.partition(",")
if m.endswith("_stderr"):
......@@ -312,9 +315,9 @@ def make_table(result_dict, column: str = "results"):
if m + "_stderr" + "," + f in dic:
se = dic[m + "_stderr" + "," + f]
values.append([k, version, f, m, "%.4f" % v, "±", "%.4f" % se])
values.append([k, version, f, n, m, "%.4f" % v, "±", "%.4f" % se])
else:
values.append([k, version, f, m, "%.4f" % v, "", ""])
values.append([k, version, f, n, m, "%.4f" % v, "", ""])
k = ""
version = ""
md_writer.value_matrix = values
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment