Unverified Commit 87d93e99 authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

use prediction and prediction index

parent 3aeb95b6
......@@ -889,22 +889,26 @@ class ConfigurableTask(Task):
# and this stores our "regular" conditional loglikelihoods
lls = lls[::2]
pred_idx = np.argmax(lls)
pred_idx_norm = np.argmax(lls / completion_len)
if self._config.gold_alias is not None:
gold = int(self.gold_alias(doc))
pred = np.argmax(lls)
pred_norm = np.argmax(lls / completion_len)
pred = pred_idx
pred_norm = pred_idx_norm
else:
gold = self.doc_to_target(doc)
pred = choices[np.argmax(lls)]
pred_norm = choices[np.argmax(lls / completion_len)]
gold_idx = choices.index(gold)
pred = choices[pred_idx]
pred_norm = choices[pred_idx_norm]
acc = 1.0 if pred == gold else 0.0
acc_norm = 1.0 if pred_norm == gold else 0.0
result_dict = {
**({"acc": acc} if "acc" in use_metric else {}),
**({"f1": (gold, pred)} if "f1" in use_metric else {}),
**({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
**({"f1": (gold_idx, pred_idx)} if "f1" in use_metric else {}),
**({"mcc": (gold_idx, pred_idx)} if "mcc" in use_metric else {}),
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment