Unverified Commit 46c79664 authored by Alex Bäuerle's avatar Alex Bäuerle Committed by GitHub
Browse files

add length of strings and answer options to metadata (#1222)

parent 6a1c19ed
...@@ -164,6 +164,7 @@ def generate_dataset( ...@@ -164,6 +164,7 @@ def generate_dataset(
{ {
"id": ids, "id": ids,
"data": instance, "data": instance,
"input_len": [len(x) for x in instance],
"labels": labels, "labels": labels,
"output_type": config["output_type"], "output_type": config["output_type"],
} }
...@@ -181,26 +182,30 @@ def generate_system_df(data, config): ...@@ -181,26 +182,30 @@ def generate_system_df(data, config):
pd.Dataframe: A dataframe that is ready to be uploaded to Zeno as a system. pd.Dataframe: A dataframe that is ready to be uploaded to Zeno as a system.
""" """
ids = [x["doc_id"] for x in data] ids = [x["doc_id"] for x in data]
answers = [""] * len(ids) system_dict = {"id": ids}
system_dict["output"] = [""] * len(ids)
if config["output_type"] == "loglikelihood": if config["output_type"] == "loglikelihood":
answers = [ system_dict["output"] = [
"correct" if x["filtered_resps"][0][1] is True else "incorrect" "correct" if x["filtered_resps"][0][1] is True else "incorrect"
for x in data for x in data
] ]
elif config["output_type"] == "multiple_choice": elif config["output_type"] == "multiple_choice":
answers = [", ".join([str(y[0]) for y in x["filtered_resps"]]) for x in data] system_dict["output"] = [
", ".join([str(y[0]) for y in x["filtered_resps"]]) for x in data
]
system_dict["num_answers"] = [len(x["filtered_resps"]) for x in data]
elif config["output_type"] == "loglikelihood_rolling": elif config["output_type"] == "loglikelihood_rolling":
answers = [str(x["filtered_resps"][0]) for x in data] system_dict["output"] = [str(x["filtered_resps"][0]) for x in data]
elif config["output_type"] == "generate_until": elif config["output_type"] == "generate_until":
answers = [str(x["filtered_resps"][0]) for x in data] system_dict["output"] = [str(x["filtered_resps"][0]) for x in data]
system_dict["output_length"] = [len(str(x["filtered_resps"][0])) for x in data]
metrics = {} metrics = {}
for metric in config["metric_list"]: for metric in config["metric_list"]:
if "aggregation" in metric and metric["aggregation"] == "mean": if "aggregation" in metric and metric["aggregation"] == "mean":
metrics[metric["metric"]] = [x[metric["metric"]] for x in data] metrics[metric["metric"]] = [x[metric["metric"]] for x in data]
system_dict = {"id": ids, "output": answers}
system_dict.update(metrics) system_dict.update(metrics)
system_df = pd.DataFrame(system_dict) system_df = pd.DataFrame(system_dict)
return system_df return system_df
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment