Commit 432bd44c authored by Leo Gao
Browse files

Fixes to make greedy_until work

# Conflicts:
#	lm_eval/models/gpt2.py
#	lm_eval/tasks/squad.py
parent e8f9dc71
...@@ -269,6 +269,7 @@ def perplexity(items): ...@@ -269,6 +269,7 @@ def perplexity(items):
req_ret_lens = { req_ret_lens = {
'loglikelihood': 2, 'loglikelihood': 2,
'greedy_until': None,
} }
import os import os
...@@ -335,11 +336,15 @@ class Request: ...@@ -335,11 +336,15 @@ class Request:
self.index = index self.index = index
def __iter__(self): def __iter__(self):
if req_ret_lens[self.type] is None:
raise IndexError('This request type does not return multiple arguments!')
i = 0 i = 0
for i in range(req_ret_lens[self.type]): for i in range(req_ret_lens[self.type]):
yield Request(self.type, self.args, i) yield Request(self.type, self.args, i)
def __getitem__(self, i): def __getitem__(self, i):
if req_ret_lens[self.type] is None:
raise IndexError('This request type does not return multiple arguments!')
return Request(self.type, self.args, i) return Request(self.type, self.args, i)
def __eq__(self, other): def __eq__(self, other):
......
...@@ -39,6 +39,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit): ...@@ -39,6 +39,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
) )
reqs = task.construct_requests(doc, ctx) reqs = task.construct_requests(doc, ctx)
if not isinstance(reqs, (list, tuple)): reqs = [reqs]
for i, req in enumerate(reqs): for i, req in enumerate(reqs):
requests[req.type].append(req) requests[req.type].append(req)
......
...@@ -19,5 +19,9 @@ class DummyLM(LM): ...@@ -19,5 +19,9 @@ class DummyLM(LM):
return res return res
def greedy_until(self, requests): def greedy_until(self, requests):
# TODO: implement res = []
pass
for _ in requests:
res.append("lol")
return res
...@@ -66,7 +66,7 @@ class GPT2LM(LM): ...@@ -66,7 +66,7 @@ class GPT2LM(LM):
cont = self.gpt2.generate( cont = self.gpt2.generate(
context_enc, context_enc,
max_length=self.MAX_GEN_TOKS, max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
eos_token_id=primary_until, eos_token_id=primary_until,
do_sample=False do_sample=False
) )
......
...@@ -66,7 +66,7 @@ class Arithmetic(Task): ...@@ -66,7 +66,7 @@ class Arithmetic(Task):
return is_prediction return is_prediction
def process_results(self, doc, results): def process_results(self, doc, results):
ll, is_prediction = results is_prediction, = results
return { return {
"acc": is_prediction "acc": is_prediction
} }
......
...@@ -30,7 +30,7 @@ class SQuAD(HFTask): ...@@ -30,7 +30,7 @@ class SQuAD(HFTask):
return "Title: The_Title_of_It\n\nBackground: A text passage as background to answer the question with.\n\nQ: Question about the passage.\n\nA: Answer." return "Title: The_Title_of_It\n\nBackground: A text passage as background to answer the question with.\n\nQ: Question about the passage.\n\nA: Answer."
def doc_to_text(self, doc): def doc_to_text(self, doc):
return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Q: ' + doc['question'] + '\n\n' + 'A: ' return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:'
def doc_to_target(self, doc): def doc_to_target(self, doc):
answer_list = doc['answers']['text'] answer_list = doc['answers']['text']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment