Commit 69c83456 authored by Jonathan Tow's avatar Jonathan Tow
Browse files

Fix `MuTual` loglikelihood request bug

parent 6fa1a4fa
...@@ -21,7 +21,7 @@ from best_download import download_file ...@@ -21,7 +21,7 @@ from best_download import download_file
class MuTualBase(Task): class MuTualBase(Task):
VERSION = 0 VERSION = 1
BASE_PATH = Path("data/mutual") BASE_PATH = Path("data/mutual")
DATASET_NAME = None DATASET_NAME = None
CHOICES = ['A', 'B', 'C', 'D'] CHOICES = ['A', 'B', 'C', 'D']
...@@ -83,7 +83,7 @@ class MuTualBase(Task): ...@@ -83,7 +83,7 @@ class MuTualBase(Task):
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
lls = [] lls = []
for option in doc["options"]: for option in doc["options"]:
lls.append(rf.loglikelihood(ctx, f" {self.detokenize(option)}")) lls.append(rf.loglikelihood(ctx, f" {self.detokenize(option)}")[0])
return lls return lls
def detokenize(self, text): def detokenize(self, text):
...@@ -100,7 +100,7 @@ class MuTualBase(Task): ...@@ -100,7 +100,7 @@ class MuTualBase(Task):
text = text.replace(" ?", "?") text = text.replace(" ?", "?")
text = text.replace(" ,", ",") text = text.replace(" ,", ",")
text = text.replace(" .", ".") text = text.replace(" .", ".")
return text return text.lower()
def process_results(self, doc, results): def process_results(self, doc, results):
gold = self.CHOICES.index(doc["answers"]) gold = self.CHOICES.index(doc["answers"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment