Commit 81495b0a authored by lintangsutawika's avatar lintangsutawika
Browse files

num fewshot is printed in table

parent 8ffd2630
...@@ -217,6 +217,8 @@ def evaluate( ...@@ -217,6 +217,8 @@ def evaluate(
# store the ordering of tasks and groups # store the ordering of tasks and groups
task_order = collections.defaultdict(int) task_order = collections.defaultdict(int)
task_group_alias = collections.defaultdict(dict) task_group_alias = collections.defaultdict(dict)
# store num-fewshot value per task
num_fewshot = collections.defaultdict(int)
# get lists of each type of request # get lists of each type of request
for task_name, task in task_dict.items(): for task_name, task in task_dict.items():
...@@ -234,6 +236,12 @@ def evaluate( ...@@ -234,6 +236,12 @@ def evaluate(
versions[task_name] = task.VERSION versions[task_name] = task.VERSION
configs[task_name] = dict(task.dump_config()) configs[task_name] = dict(task.dump_config())
if "num_fewshot" in configs[task_name]:
n_shot = configs[task_name]["num_fewshot"]
else:
n_shot = -1
num_fewshot[task_name] = n_shot
if "task_alias" in configs[task_name]: if "task_alias" in configs[task_name]:
task_group_alias[task_name] = configs[task_name]["task_alias"] task_group_alias[task_name] = configs[task_name]["task_alias"]
...@@ -564,6 +572,7 @@ def evaluate( ...@@ -564,6 +572,7 @@ def evaluate(
_results_agg = collections.defaultdict(dict) _results_agg = collections.defaultdict(dict)
_versions = collections.defaultdict(dict) _versions = collections.defaultdict(dict)
_num_fewshot = collections.defaultdict(int)
for task in results_agg: for task in results_agg:
task_results = results_agg[task] task_results = results_agg[task]
...@@ -579,11 +588,14 @@ def evaluate( ...@@ -579,11 +588,14 @@ def evaluate(
task_alias = task_group_alias[task] task_alias = task_group_alias[task]
_results_agg[tab_string + task_alias] = task_results _results_agg[tab_string + task_alias] = task_results
_versions[tab_string + task_alias] = versions[task] _versions[tab_string + task_alias] = versions[task]
_num_fewshot[tab_string + task_alias] = num_fewshot[task]
else: else:
_results_agg[tab_string + task] = task_results _results_agg[tab_string + task] = task_results
_versions[tab_string + task] = versions[task] _versions[tab_string + task] = versions[task]
_num_fewshot[tab_string + task] = num_fewshot[task]
results_agg = _results_agg results_agg = _results_agg
versions = _versions versions = _versions
num_fewshot = _num_fewshot
_groups_agg = collections.defaultdict(dict) _groups_agg = collections.defaultdict(dict)
for group in groups_agg: for group in groups_agg:
...@@ -609,6 +621,7 @@ def evaluate( ...@@ -609,6 +621,7 @@ def evaluate(
**({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}), **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
"configs": dict(sorted(configs.items())), "configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())), "versions": dict(sorted(versions.items())),
"n-shot": dict(sorted(num_fewshot.items())),
} }
if log_samples: if log_samples:
results_dict["samples"] = dict(samples) results_dict["samples"] = dict(samples)
......
...@@ -286,6 +286,7 @@ def make_table(result_dict, column: str = "results"): ...@@ -286,6 +286,7 @@ def make_table(result_dict, column: str = "results"):
column_name, column_name,
"Version", "Version",
"Filter", "Filter",
"n-shot",
"Metric", "Metric",
"Value", "Value",
"", "",
...@@ -295,6 +296,7 @@ def make_table(result_dict, column: str = "results"): ...@@ -295,6 +296,7 @@ def make_table(result_dict, column: str = "results"):
column_name, column_name,
"Version", "Version",
"Filter", "Filter",
"n-shot",
"Metric", "Metric",
"Value", "Value",
"", "",
...@@ -305,6 +307,7 @@ def make_table(result_dict, column: str = "results"): ...@@ -305,6 +307,7 @@ def make_table(result_dict, column: str = "results"):
for k, dic in result_dict[column].items(): for k, dic in result_dict[column].items():
version = result_dict["versions"][k] version = result_dict["versions"][k]
n = str(result_dict["n-shot"][k])
for (mf), v in dic.items(): for (mf), v in dic.items():
m, _, f = mf.partition(",") m, _, f = mf.partition(",")
if m.endswith("_stderr"): if m.endswith("_stderr"):
...@@ -312,9 +315,9 @@ def make_table(result_dict, column: str = "results"): ...@@ -312,9 +315,9 @@ def make_table(result_dict, column: str = "results"):
if m + "_stderr" + "," + f in dic: if m + "_stderr" + "," + f in dic:
se = dic[m + "_stderr" + "," + f] se = dic[m + "_stderr" + "," + f]
values.append([k, version, f, m, "%.4f" % v, "±", "%.4f" % se]) values.append([k, version, f, n, m, "%.4f" % v, "±", "%.4f" % se])
else: else:
values.append([k, version, f, m, "%.4f" % v, "", ""]) values.append([k, version, f, n, m, "%.4f" % v, "", ""])
k = "" k = ""
version = "" version = ""
md_writer.value_matrix = values md_writer.value_matrix = values
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment