Commit c971fa82 authored by Leo Gao's avatar Leo Gao
Browse files

Fix stuff and make tests pass

parent 0966e7b6
......@@ -60,12 +60,16 @@ class GPT2LM(LM):
greedy_tokens = logits.argmax(dim=-1)
max_equal = (greedy_tokens == cont_toks).all()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [batch, seq]
last_token_slice = logits[:, -1, :].squeeze(0).tolist()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [batch, seq]
res.append((float(logits.sum()), bool(max_equal)))
res.append((float(logits[:, :-1].sum() if logits.shape[-1] > 1 else 0), last_token_slice, bool(max_equal)))
return reord.get_original(res)
# optimization: if two requests have everything the same except the last token, use
# last token distribution to save compute
lasttoks = [self.tokenizer.encode(x[1])[-1] for x in requests]
return [(l + lts[lasttok], m) for (l, lts, m), lasttok in zip(reord.get_original(res), lasttoks)]
def greedy_until(self, requests):
# TODO: implement fully general `until` that handles untils that are
......
......@@ -334,7 +334,7 @@ class MRPC(HFTask):
return True
def has_test_docs(self):
return True
return False
def fewshot_description(self):
    """Return the fixed instruction string prepended to few-shot prompts for this task."""
    instruction = "Indicate if both sentences mean the same thing."
    return instruction
......@@ -386,7 +386,7 @@ class QQP(HFTask):
return True
def has_test_docs(self):
return True
return False
def fewshot_description(self):
    """Return the fixed instruction string prepended to few-shot prompts for this task."""
    instruction = "Indicate if both questions ask the same thing."
    return instruction
......
......@@ -29,4 +29,4 @@ def test_evaluator(taskname, Task):
lm.loglikelihood = ll_fn
evaluator.evaluate(lm, task_dict, False, 0, 10)
\ No newline at end of file
evaluator.evaluate(lm, task_dict, False, 0, 3)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment