Commit c37ad6ec authored by lintangsutawika's avatar lintangsutawika
Browse files

process multiple_output

parent 4c363ef8
...@@ -919,9 +919,6 @@ class ConfigurableTask(Task): ...@@ -919,9 +919,6 @@ class ConfigurableTask(Task):
choices = self.doc_to_choice(doc) choices = self.doc_to_choice(doc)
completion_len = np.array([float(len(i)) for i in choices]) completion_len = np.array([float(len(i)) for i in choices])
if self.multiple_output:
pass
if ( if (
2 * len(choices) == len(lls) 2 * len(choices) == len(lls)
and "acc_mutual_info" in self._metric_fn_list.keys() and "acc_mutual_info" in self._metric_fn_list.keys()
...@@ -946,14 +943,25 @@ class ConfigurableTask(Task): ...@@ -946,14 +943,25 @@ class ConfigurableTask(Task):
pred = choices[pred_idx] pred = choices[pred_idx]
pred_norm = choices[pred_idx_norm] pred_norm = choices[pred_idx_norm]
gold = self.doc_to_target(doc) gold = self.doc_to_target(doc)
if type(gold) == int: if self.multiple_output:
gold_idx = gold if type(gold[0]) == int:
gold = choices[gold_idx] gold_idx = gold
elif type(gold) == str: gold = [choices[idx] for idx in gold_idx]
gold_idx = choices.index(gold) elif type(gold[0]) == str:
gold_idx = [choices.index(g) for g in gold]
acc = 1.0 if pred == gold else 0.0 else:
acc_norm = 1.0 if pred_norm == gold else 0.0 if type(gold) == int:
gold_idx = gold
gold = choices[gold_idx]
elif type(gold) == str:
gold_idx = choices.index(gold)
if self.multiple_output:
acc = 1.0 if pred in gold else 0.0
acc_norm = 1.0 if pred_norm in gold else 0.0
else:
acc = 1.0 if pred == gold else 0.0
acc_norm = 1.0 if pred_norm == gold else 0.0
result_dict = { result_dict = {
**({"acc": acc} if "acc" in use_metric else {}), **({"acc": acc} if "acc" in use_metric else {}),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment