Commit c37ad6ec authored by lintangsutawika's avatar lintangsutawika
Browse files

process multiple_output

parent 4c363ef8
...@@ -919,9 +919,6 @@ class ConfigurableTask(Task): ...@@ -919,9 +919,6 @@ class ConfigurableTask(Task):
choices = self.doc_to_choice(doc) choices = self.doc_to_choice(doc)
completion_len = np.array([float(len(i)) for i in choices]) completion_len = np.array([float(len(i)) for i in choices])
if self.multiple_output:
pass
if ( if (
2 * len(choices) == len(lls) 2 * len(choices) == len(lls)
and "acc_mutual_info" in self._metric_fn_list.keys() and "acc_mutual_info" in self._metric_fn_list.keys()
...@@ -946,12 +943,23 @@ class ConfigurableTask(Task): ...@@ -946,12 +943,23 @@ class ConfigurableTask(Task):
pred = choices[pred_idx] pred = choices[pred_idx]
pred_norm = choices[pred_idx_norm] pred_norm = choices[pred_idx_norm]
gold = self.doc_to_target(doc) gold = self.doc_to_target(doc)
if self.multiple_output:
if type(gold[0]) == int:
gold_idx = gold
gold = [choices[idx] for idx in gold_idx]
elif type(gold[0]) == str:
gold_idx = [choices.index(g) for g in gold]
else:
if type(gold) == int: if type(gold) == int:
gold_idx = gold gold_idx = gold
gold = choices[gold_idx] gold = choices[gold_idx]
elif type(gold) == str: elif type(gold) == str:
gold_idx = choices.index(gold) gold_idx = choices.index(gold)
if self.multiple_output:
acc = 1.0 if pred in gold else 0.0
acc_norm = 1.0 if pred_norm in gold else 0.0
else:
acc = 1.0 if pred == gold else 0.0 acc = 1.0 if pred == gold else 0.0
acc_norm = 1.0 if pred_norm == gold else 0.0 acc_norm = 1.0 if pred_norm == gold else 0.0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment