Commit c2fcf688 authored by gk's avatar gk
Browse files

Add MultipleChoiceExactTask

parent 095d8406
...@@ -264,7 +264,7 @@ class BaseLM(LM): ...@@ -264,7 +264,7 @@ class BaseLM(LM):
_, context_enc, continuation_enc = re_ord.get_reordered()[0] _, context_enc, continuation_enc = re_ord.get_reordered()[0]
max_context = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]) max_context = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
if (self.batch_size == 'auto'): if (self.batch_size == 'auto'):
if override_bs is None: if override_bs is None:
print('Passed argument batch_size = auto. Detecting largest batch size') print('Passed argument batch_size = auto. Detecting largest batch size')
@find_executable_batch_size(starting_batch_size=512) # if OOM, then halves batch_size and tries again @find_executable_batch_size(starting_batch_size=512) # if OOM, then halves batch_size and tries again
...@@ -734,6 +734,20 @@ class MultipleChoiceTask(Task): ...@@ -734,6 +734,20 @@ class MultipleChoiceTask(Task):
} }
class MultipleChoiceExactTask(MultipleChoiceTask):
def construct_requests(self, doc, ctx):
return rf.loglikelihood(ctx, self.doc_to_target(doc))[1]
def process_results(self, doc, results):
return {"acc": 1.0 if results[0] else 0.0}
def higher_is_better(self):
return {"acc": True}
def aggregation(self):
return {"acc": mean}
class PerplexityTask(Task, abc.ABC): class PerplexityTask(Task, abc.ABC):
def should_decontaminate(self): def should_decontaminate(self):
"""Whether this task supports decontamination against model training set.""" """Whether this task supports decontamination against model training set."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment