Commit 1bf3c505 authored by kabbi159
Browse files

fix: add acc_norm in kobest_hellaswag

parent cc7650d8
...@@ -248,20 +248,28 @@ class HellaSwag(MultipleChoiceTask): ...@@ -248,20 +248,28 @@ class HellaSwag(MultipleChoiceTask):
def process_results(self, doc, results): def process_results(self, doc, results):
pred = np.argmax(results) pred = np.argmax(results)
gold = doc["gold"] gold = doc["gold"]
acc = 1. if np.argmax(results) == gold else 0.
completion_len = np.array([float(len(i)) for i in doc["choices"]])
acc_norm = 1. if np.argmax(results / completion_len) == gold else 0.
return { return {
"acc": pred == gold, "acc": acc,
"acc_norm": acc_norm,
"macro_f1": (gold, pred) "macro_f1": (gold, pred)
} }
def higher_is_better(self): def higher_is_better(self):
return { return {
"acc": True, "acc": True,
"acc_norm": True,
"macro_f1": True "macro_f1": True
} }
def aggregation(self): def aggregation(self):
return { return {
"acc": mean, "acc": mean,
"acc_norm": mean,
"macro_f1": macro_f1_score "macro_f1": macro_f1_score
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment