Commit a76754ff authored by lintangsutawika
Browse files

process brier_score

parent 98f9bac9
......@@ -972,7 +972,10 @@ class ConfigurableTask(Task):
def process_results(self, doc, results):
if callable(self.config.process_results):
return self.config.process_results(doc, results)
try:
return self.config.process_results(self, doc, results)
except:
return self.config.process_results(doc, results)
result_dict = {}
use_metric = list(self._metric_fn_list.keys())
......@@ -1060,12 +1063,15 @@ class ConfigurableTask(Task):
# TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
exact_match = int(is_greedy[gold]) if gold != -100 else 0
prob_norm = [float(i)/sum(lls) for i in lls]
result_dict = {
**({"acc": acc} if "acc" in use_metric else {}),
**({"f1": (gold, pred)} if "f1" in use_metric else {}),
**({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
**({"exact_match": exact_match} if "exact_match" in use_metric else {}),
**({"brier_score": (gold, prob_norm)} if "brier_score" in use_metric else {}),
}
if "acc_mutual_info" in use_metric:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment