Commit 706cb53a authored by Leo Gao's avatar Leo Gao
Browse files

Add MultipleChoiceTask

parent 2a1d7d87
...@@ -180,6 +180,35 @@ class Task(abc.ABC): ...@@ -180,6 +180,35 @@ class Task(abc.ABC):
return description + labeled_examples + example return description + labeled_examples + example
class MultipleChoiceTask(Task):
def construct_requests(self, doc, ctx):
lls = [
rf.loglikelihood(ctx, " {}".format(choice))[0]
for choice in doc['choices']
]
return lls
def process_results(self, doc, results):
gold = doc["gold"]
acc = 1. if np.argmax(results) == gold else 0.
return {
"acc": acc
}
def higher_is_better(self):
return {
"acc": True
}
def aggregation(self):
return {
"acc": mean
}
def mean(arr): def mean(arr):
return sum(arr) / len(arr) return sum(arr) / len(arr)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment