Commit 7b649ded authored by Leo Gao

Fixes to make greedy_until work

parent eb4c8407
@@ -269,6 +269,7 @@ def perplexity(items):
 req_ret_lens = {
     'loglikelihood': 2,
+    'greedy_until': None,
 }
 
 import os
@@ -335,11 +336,15 @@ class Request:
         self.index = index
 
     def __iter__(self):
+        if req_ret_lens[self.type] is None:
+            raise IndexError('This request type does not return multiple arguments!')
         i = 0
         for i in range(req_ret_lens[self.type]):
             yield Request(self.type, self.args, i)
 
     def __getitem__(self, i):
+        if req_ret_lens[self.type] is None:
+            raise IndexError('This request type does not return multiple arguments!')
         return Request(self.type, self.args, i)
 
     def __eq__(self, other):
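Note: a minimal sketch (not part of this commit) of what the new req_ret_lens guards change; variable names are illustrative, and Request(type, args) construction is assumed from the calls shown elsewhere in this diff.

# Sketch only: relies on the req_ret_lens table and Request class shown above.
ll_req = Request('loglikelihood', ("some context", " a continuation"))
ll, is_greedy = ll_req   # OK: req_ret_lens['loglikelihood'] == 2, so __iter__ yields two indexed sub-requests

gu_req = Request('greedy_until', ("some context", ["\n"]))
# req_ret_lens['greedy_until'] is None, so unpacking or indexing gu_req now raises a clear
# IndexError('This request type does not return multiple arguments!') rather than a
# TypeError from range(None).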
@@ -39,6 +39,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
             )
 
             reqs = task.construct_requests(doc, ctx)
+            if not isinstance(reqs, (list, tuple)): reqs = [reqs]
             for i, req in enumerate(reqs):
                 requests[req.type].append(req)
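Note: a rough, self-contained illustration (not from this commit) of the normalization idiom the added guard uses; the return shapes described in the comment are assumptions about how tasks build their requests.

# construct_requests may return either a list of Requests (e.g. one loglikelihood
# Request per answer choice) or a single bare Request (e.g. one greedy_until Request
# for a generation-style task); the guard makes both cases iterable the same way.
def normalize(reqs):
    if not isinstance(reqs, (list, tuple)):
        reqs = [reqs]
    return reqs

assert normalize("one_request") == ["one_request"]
assert normalize(["req_a", "req_b"]) == ["req_a", "req_b"]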
@@ -19,5 +19,9 @@ class DummyLM(LM):
         return res
 
     def greedy_until(self, requests):
-        # TODO: implement
-        pass
+        res = []
+
+        for _ in requests:
+            res.append("lol")
+
+        return res
@@ -49,5 +49,29 @@ class GPT2LM(LM):
         return res
 
     def greedy_until(self, requests):
-        # TODO: implement
-        pass
+        # TODO: implement fully general `until` that handles untils that are
+        # multiple tokens or that span multiple tokens correctly
+        res = []
+
+        for context, until in tqdm(requests):
+            if isinstance(until, str): until = [until]
+
+            context_enc = torch.tensor([self.tokenizer.encode(context)]).to(self.device)
+
+            primary_until, = self.tokenizer.encode(until[0])
+
+            cont = self.gpt2.generate(
+                context_enc,
+                max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
+                eos_token_id=primary_until,
+                do_sample=False
+            )
+
+            s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])
+
+            for term in until:
+                s = s.split(term)[0]
+
+            res.append(s)
+
+        return res
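Note: a hedged usage sketch (not part of this commit) of how these greedy_until requests look from the caller's side; `lm` is assumed to be an already-constructed GPT2LM, and the example prompts are illustrative.

# Each request is a (context, until) pair; `until` may be a single stop string or a
# list of stop strings. Generation is greedy (do_sample=False), stops early when the
# first token of until[0] is produced (via eos_token_id), and the decoded continuation
# is then truncated at the earliest occurrence of any stop string.
requests = [
    ("Question: What is the capital of France?\nAnswer:", "\n"),
    ("Translate to French: cheese\n", ["\n", "."]),
]
continuations = lm.greedy_until(requests)   # one decoded string per request, context stripped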
@@ -63,7 +63,7 @@ class Arithmetic(Task):
         return is_prediction
 
     def process_results(self, doc, results):
-        ll, is_prediction = results
+        is_prediction, = results
         return {
             "acc": is_prediction
         }
@@ -26,7 +26,7 @@ class SQuAD(HFTask):
         return ""
 
     def doc_to_text(self, doc):
-        return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Q: ' + doc['question'] + '\n\n' + 'A:'
+        return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:'
 
     def doc_to_target(self, doc):
         answer_list = doc['answers']['text']
@@ -62,9 +62,11 @@ class SQuAD(HFTask):
         """
         squad_metric = datasets.load_metric("squad_v2")
 
+        continuation, = results
+
         predictions = {
             'id': doc['id'],
-            'prediction_text': results[0],
+            'prediction_text': continuation,
         }
 
         references = {
@@ -8,7 +8,7 @@ import pytest
 # TODO: more fine grained unit tests rather than this big honking integration
 # test once we break evaluator into smaller, more manageable pieces
-@pytest.mark.parametrize("taskname,Task", tasks.TASK_REGISTRY.items())
+@pytest.mark.parametrize("taskname,Task", [('squad', tasks.squad.SQuAD)])
 def test_evaluator(taskname, Task):
     task_dict = tasks.get_task_dict([taskname])
     lm = models.get_model('dummy')()