Commit 049dfa34 authored by Leo Gao
Browse files

Add test to make sure right space conventions are used

parent aea63162
......@@ -176,7 +176,7 @@ class Task(abc.ABC):
[self.doc_to_text(doc) + self.doc_to_target(doc) for doc in self.fewshot_examples(k=num_fewshot)]
) + "\n\n"
example = self.doc_to_text(doc).strip()
example = self.doc_to_text(doc)
return description + labeled_examples + example
......
......@@ -61,7 +61,7 @@ class HellaSwag(HFTask):
raise ValueError(
"HellaSwag from HF datasets contained an invalid answer key")
target = doc['endings'][index]
return self.remove_brackets(target)
return " " + self.remove_brackets(target)
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
......@@ -75,7 +75,7 @@ class HellaSwag(HFTask):
"""
ll_answers = []
for i in range(4):
continuation = self.remove_brackets(doc['endings'][i])
continuation = " " + self.remove_brackets(doc['endings'][i])
ll_answers.append(rf.loglikelihood(ctx, continuation))
return ll_answers
......
......@@ -46,12 +46,12 @@ class PiQA(Task):
return ""
def doc_to_text(self, doc):
return doc[0]['goal']
return doc[0]['goal'] + "\n"
def doc_to_target(self, doc):
#TODO: check if oa uses newline
rightanswer = int(doc[1]) + 1
return '\n' + ''.join([doc[0]['goal'],' ',doc[0]['sol'+str(rightanswer)]])
return ''.join([doc[0]['goal'],' ',doc[0]['sol'+str(rightanswer)]])
def construct_requests(self, doc, ctx):
ll_1, _ = rf.loglikelihood(ctx, doc[0]['sol1'])
......
......@@ -28,10 +28,10 @@ class BoolQ(HFTask):
return "Read the following passages and answer each question with a yes or a no."
def doc_to_text(self, doc):
return f"{doc['passage']}\nquestion: {doc['question']}\nanswer: "
return f"{doc['passage']}\nquestion: {doc['question']}\nanswer:"
def doc_to_target(self, doc):
return yesno(doc['label'])
return " " + yesno(doc['label'])
def construct_requests(self, doc, ctx):
......@@ -156,12 +156,12 @@ class Copa(HFTask):
"cause": "because",
"effect": "therefore",
}[doc["question"]]
return doc["premise"].strip()[:-1] + f" {connector} "
return doc["premise"].strip()[:-1] + f" {connector}"
def doc_to_target(self, doc):
correct_choice = doc["choice1"] if doc["label"] == 0 else doc["choice2"]
# Connect the sentences
return self.convert_choice(correct_choice)
return " " + self.convert_choice(correct_choice)
def construct_requests(self, doc, ctx):
choice1 = " " + self.convert_choice(doc["choice1"])
......
import lm_eval.tasks as tasks
import lm_eval.models as models
import lm_eval.evaluator as evaluator
import random
import pytest
......@@ -11,4 +12,21 @@ import pytest
def test_evaluator(taskname, Task):
    """Run the evaluator on one task with a stubbed log-likelihood function.

    The stub checks the whitespace convention on every (context,
    continuation) request it receives: the context must not end with a
    space, and the continuation must start with a space unless the context
    ends with a newline. It then returns deterministic pseudo-random scores
    so the evaluation can complete without a real model.
    """
    task_dict = tasks.get_task_dict([taskname])
    lm = models.get_model('dummy')()

    def ll_fn(reqs):
        """Validate the space convention and return fake log-likelihoods."""
        for ctx, cont in reqs:
            # An empty context has no trailing character to inspect; indexing
            # it would raise IndexError instead of testing the convention.
            if not ctx:
                continue

            # space convention
            assert not ctx.endswith(' '), (
                f"context must not end with a space: {ctx!r}"
            )
            assert cont.startswith(' ') or ctx.endswith('\n'), (
                "continuation must start with a space unless the context "
                f"ends with a newline: ctx={ctx!r}, cont={cont!r}"
            )

        # Re-seed on every call so the fake scores are reproducible.
        random.seed(42)
        return [(-random.random(), False) for _ in reqs]

    lm.loglikelihood = ll_fn
    # NOTE(review): positional arguments presumably are
    # (provide_description, num_fewshot, limit) - confirm against
    # evaluator.evaluate.
    evaluator.evaluate(lm, task_dict, False, 0, 10)
\ No newline at end of file
import lm_eval.tasks as tasks
import lm_eval.base as base
from unittest.mock import MagicMock
from itertools import islice
import pytest
......@@ -43,6 +42,10 @@ def test_documents_and_requests(taskname, Task):
assert isinstance(txt, str)
assert isinstance(tgt, str)
# space convention
assert txt[-1] != ' '
assert tgt[0] == ' ' or txt[-1] == '\n'
reqs = task.construct_requests(doc, txt)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment