Commit c37ad6ec authored by lintangsutawika's avatar lintangsutawika
Browse files

process multiple_output

parent 4c363ef8
......@@ -919,9 +919,6 @@ class ConfigurableTask(Task):
choices = self.doc_to_choice(doc)
completion_len = np.array([float(len(i)) for i in choices])
if self.multiple_output:
pass
if (
2 * len(choices) == len(lls)
and "acc_mutual_info" in self._metric_fn_list.keys()
......@@ -946,14 +943,25 @@ class ConfigurableTask(Task):
pred = choices[pred_idx]
pred_norm = choices[pred_idx_norm]
gold = self.doc_to_target(doc)
if type(gold) == int:
gold_idx = gold
gold = choices[gold_idx]
elif type(gold) == str:
gold_idx = choices.index(gold)
acc = 1.0 if pred == gold else 0.0
acc_norm = 1.0 if pred_norm == gold else 0.0
if self.multiple_output:
if type(gold[0]) == int:
gold_idx = gold
gold = [choices[idx] for idx in gold_idx]
elif type(gold[0]) == str:
gold_idx = [choices.index(g) for g in gold]
else:
if type(gold) == int:
gold_idx = gold
gold = choices[gold_idx]
elif type(gold) == str:
gold_idx = choices.index(gold)
if self.multiple_output:
acc = 1.0 if pred in gold else 0.0
acc_norm = 1.0 if pred_norm in gold else 0.0
else:
acc = 1.0 if pred == gold else 0.0
acc_norm = 1.0 if pred_norm == gold else 0.0
result_dict = {
**({"acc": acc} if "acc" in use_metric else {}),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment