"include/ck/utility/multi_index.hpp" did not exist on "30072aec376a2fc6c6a465fe632bc7be050ec5ea"
Commit 049dfa34 authored by Leo Gao's avatar Leo Gao
Browse files

Add test to make sure right space conventions are used

parent aea63162
...@@ -176,7 +176,7 @@ class Task(abc.ABC): ...@@ -176,7 +176,7 @@ class Task(abc.ABC):
[self.doc_to_text(doc) + self.doc_to_target(doc) for doc in self.fewshot_examples(k=num_fewshot)] [self.doc_to_text(doc) + self.doc_to_target(doc) for doc in self.fewshot_examples(k=num_fewshot)]
) + "\n\n" ) + "\n\n"
example = self.doc_to_text(doc).strip() example = self.doc_to_text(doc)
return description + labeled_examples + example return description + labeled_examples + example
......
...@@ -61,7 +61,7 @@ class HellaSwag(HFTask): ...@@ -61,7 +61,7 @@ class HellaSwag(HFTask):
raise ValueError( raise ValueError(
"HellaSwag from HF datasets contained an invalid answer key") "HellaSwag from HF datasets contained an invalid answer key")
target = doc['endings'][index] target = doc['endings'][index]
return self.remove_brackets(target) return " " + self.remove_brackets(target)
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """ Uses RequestFactory to construct Requests and returns an iterable of
...@@ -75,7 +75,7 @@ class HellaSwag(HFTask): ...@@ -75,7 +75,7 @@ class HellaSwag(HFTask):
""" """
ll_answers = [] ll_answers = []
for i in range(4): for i in range(4):
continuation = self.remove_brackets(doc['endings'][i]) continuation = " " + self.remove_brackets(doc['endings'][i])
ll_answers.append(rf.loglikelihood(ctx, continuation)) ll_answers.append(rf.loglikelihood(ctx, continuation))
return ll_answers return ll_answers
......
...@@ -46,12 +46,12 @@ class PiQA(Task): ...@@ -46,12 +46,12 @@ class PiQA(Task):
return "" return ""
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc[0]['goal'] return doc[0]['goal'] + "\n"
def doc_to_target(self, doc): def doc_to_target(self, doc):
#TODO: check if oa uses newline #TODO: check if oa uses newline
rightanswer = int(doc[1]) + 1 rightanswer = int(doc[1]) + 1
return '\n' + ''.join([doc[0]['goal'],' ',doc[0]['sol'+str(rightanswer)]]) return ''.join([doc[0]['goal'],' ',doc[0]['sol'+str(rightanswer)]])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ll_1, _ = rf.loglikelihood(ctx, doc[0]['sol1']) ll_1, _ = rf.loglikelihood(ctx, doc[0]['sol1'])
......
...@@ -28,10 +28,10 @@ class BoolQ(HFTask): ...@@ -28,10 +28,10 @@ class BoolQ(HFTask):
return "Read the following passages and answer each question with a yes or a no." return "Read the following passages and answer each question with a yes or a no."
def doc_to_text(self, doc): def doc_to_text(self, doc):
return f"{doc['passage']}\nquestion: {doc['question']}\nanswer: " return f"{doc['passage']}\nquestion: {doc['question']}\nanswer:"
def doc_to_target(self, doc): def doc_to_target(self, doc):
return yesno(doc['label']) return " " + yesno(doc['label'])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
...@@ -156,12 +156,12 @@ class Copa(HFTask): ...@@ -156,12 +156,12 @@ class Copa(HFTask):
"cause": "because", "cause": "because",
"effect": "therefore", "effect": "therefore",
}[doc["question"]] }[doc["question"]]
return doc["premise"].strip()[:-1] + f" {connector} " return doc["premise"].strip()[:-1] + f" {connector}"
def doc_to_target(self, doc): def doc_to_target(self, doc):
correct_choice = doc["choice1"] if doc["label"] == 0 else doc["choice2"] correct_choice = doc["choice1"] if doc["label"] == 0 else doc["choice2"]
# Connect the sentences # Connect the sentences
return self.convert_choice(correct_choice) return " " + self.convert_choice(correct_choice)
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
choice1 = " " + self.convert_choice(doc["choice1"]) choice1 = " " + self.convert_choice(doc["choice1"])
......
import lm_eval.tasks as tasks import lm_eval.tasks as tasks
import lm_eval.models as models import lm_eval.models as models
import lm_eval.evaluator as evaluator import lm_eval.evaluator as evaluator
import random
import pytest import pytest
...@@ -11,4 +12,21 @@ import pytest ...@@ -11,4 +12,21 @@ import pytest
def test_evaluator(taskname, Task): def test_evaluator(taskname, Task):
task_dict = tasks.get_task_dict([taskname]) task_dict = tasks.get_task_dict([taskname])
lm = models.get_model('dummy')() lm = models.get_model('dummy')()
def ll_fn(reqs):
for ctx, cont in reqs:
# space convention
assert ctx[-1] != ' '
assert cont[0] == ' ' or ctx[-1] == '\n'
res = []
random.seed(42)
for _ in reqs:
res.append((-random.random(), False))
return res
lm.loglikelihood = ll_fn
evaluator.evaluate(lm, task_dict, False, 0, 10) evaluator.evaluate(lm, task_dict, False, 0, 10)
\ No newline at end of file
import lm_eval.tasks as tasks import lm_eval.tasks as tasks
import lm_eval.base as base import lm_eval.base as base
from unittest.mock import MagicMock
from itertools import islice from itertools import islice
import pytest import pytest
...@@ -43,6 +42,10 @@ def test_documents_and_requests(taskname, Task): ...@@ -43,6 +42,10 @@ def test_documents_and_requests(taskname, Task):
assert isinstance(txt, str) assert isinstance(txt, str)
assert isinstance(tgt, str) assert isinstance(tgt, str)
# space convention
assert txt[-1] != ' '
assert tgt[0] == ' ' or txt[-1] == '\n'
reqs = task.construct_requests(doc, txt) reqs = task.construct_requests(doc, txt)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment