"src/vscode:/vscode.git/clone" did not exist on "b3086ac2606d4b6999788f7faf06afa30406e44e"
Commit 7b649ded authored by Leo Gao
Browse files

Fixes to make greedy_until work

parent eb4c8407
...@@ -269,6 +269,7 @@ def perplexity(items): ...@@ -269,6 +269,7 @@ def perplexity(items):
req_ret_lens = { req_ret_lens = {
'loglikelihood': 2, 'loglikelihood': 2,
'greedy_until': None,
} }
import os import os
...@@ -335,11 +336,15 @@ class Request: ...@@ -335,11 +336,15 @@ class Request:
self.index = index self.index = index
def __iter__(self): def __iter__(self):
if req_ret_lens[self.type] is None:
raise IndexError('This request type does not return multiple arguments!')
i = 0 i = 0
for i in range(req_ret_lens[self.type]): for i in range(req_ret_lens[self.type]):
yield Request(self.type, self.args, i) yield Request(self.type, self.args, i)
def __getitem__(self, i): def __getitem__(self, i):
if req_ret_lens[self.type] is None:
raise IndexError('This request type does not return multiple arguments!')
return Request(self.type, self.args, i) return Request(self.type, self.args, i)
def __eq__(self, other): def __eq__(self, other):
......
...@@ -39,6 +39,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit): ...@@ -39,6 +39,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
) )
reqs = task.construct_requests(doc, ctx) reqs = task.construct_requests(doc, ctx)
if not isinstance(reqs, (list, tuple)): reqs = [reqs]
for i, req in enumerate(reqs): for i, req in enumerate(reqs):
requests[req.type].append(req) requests[req.type].append(req)
......
...@@ -19,5 +19,9 @@ class DummyLM(LM): ...@@ -19,5 +19,9 @@ class DummyLM(LM):
return res return res
def greedy_until(self, requests): def greedy_until(self, requests):
# TODO: implement res = []
pass
for _ in requests:
res.append("lol")
return res
...@@ -49,5 +49,29 @@ class GPT2LM(LM): ...@@ -49,5 +49,29 @@ class GPT2LM(LM):
return res return res
def greedy_until(self, requests): def greedy_until(self, requests):
# TODO: implement # TODO: implement fully general `until` that handles untils that are
pass # multiple tokens or that span multiple tokens correctly
res = []
for context, until in tqdm(requests):
if isinstance(until, str): until = [until]
context_enc = torch.tensor([self.tokenizer.encode(context)]).to(self.device)
primary_until, = self.tokenizer.encode(until[0])
cont = self.gpt2.generate(
context_enc,
max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
eos_token_id=primary_until,
do_sample=False
)
s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])
for term in until:
s = s.split(term)[0]
res.append(s)
return res
...@@ -63,7 +63,7 @@ class Arithmetic(Task): ...@@ -63,7 +63,7 @@ class Arithmetic(Task):
return is_prediction return is_prediction
def process_results(self, doc, results): def process_results(self, doc, results):
ll, is_prediction = results is_prediction, = results
return { return {
"acc": is_prediction "acc": is_prediction
} }
......
...@@ -26,7 +26,7 @@ class SQuAD(HFTask): ...@@ -26,7 +26,7 @@ class SQuAD(HFTask):
return "" return ""
def doc_to_text(self, doc): def doc_to_text(self, doc):
return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Q: ' + doc['question'] + '\n\n' + 'A:' return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:'
def doc_to_target(self, doc): def doc_to_target(self, doc):
answer_list = doc['answers']['text'] answer_list = doc['answers']['text']
...@@ -62,9 +62,11 @@ class SQuAD(HFTask): ...@@ -62,9 +62,11 @@ class SQuAD(HFTask):
""" """
squad_metric = datasets.load_metric("squad_v2") squad_metric = datasets.load_metric("squad_v2")
continuation, = results
predictions = { predictions = {
'id': doc['id'], 'id': doc['id'],
'prediction_text': results[0], 'prediction_text': continuation,
} }
references = { references = {
......
...@@ -8,7 +8,7 @@ import pytest ...@@ -8,7 +8,7 @@ import pytest
# TODO: more fine grained unit tests rather than this big honking integration # TODO: more fine grained unit tests rather than this big honking integration
# test once we break evaluator into smaller, more manageable pieces # test once we break evaluator into smaller, more manageable pieces
@pytest.mark.parametrize("taskname,Task", tasks.TASK_REGISTRY.items()) @pytest.mark.parametrize("taskname,Task", [('squad', tasks.squad.SQuAD)])
def test_evaluator(taskname, Task): def test_evaluator(taskname, Task):
task_dict = tasks.get_task_dict([taskname]) task_dict = tasks.get_task_dict([taskname])
lm = models.get_model('dummy')() lm = models.get_model('dummy')()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment