"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "5ba79ee039f8bef8ecbc98d92b88acd1c9a5e90e"
Commit c971fa82 authored by Leo Gao's avatar Leo Gao
Browse files

Fix stuff and make tests pass

parent 0966e7b6
...@@ -60,12 +60,16 @@ class GPT2LM(LM): ...@@ -60,12 +60,16 @@ class GPT2LM(LM):
greedy_tokens = logits.argmax(dim=-1) greedy_tokens = logits.argmax(dim=-1)
max_equal = (greedy_tokens == cont_toks).all() max_equal = (greedy_tokens == cont_toks).all()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [batch, seq] last_token_slice = logits[:, -1, :].squeeze(0).tolist()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [batch, seq]
res.append((float(logits.sum()), bool(max_equal))) res.append((float(logits[:, :-1].sum() if logits.shape[-1] > 1 else 0), last_token_slice, bool(max_equal)))
return reord.get_original(res) # optimization: if two requests have everything the same except the last token, use
# last token distribution to save compute
lasttoks = [self.tokenizer.encode(x[1])[-1] for x in requests]
return [(l + lts[lasttok], m) for (l, lts, m), lasttok in zip(reord.get_original(res), lasttoks)]
def greedy_until(self, requests): def greedy_until(self, requests):
# TODO: implement fully general `until` that handles untils that are # TODO: implement fully general `until` that handles untils that are
......
...@@ -334,7 +334,7 @@ class MRPC(HFTask): ...@@ -334,7 +334,7 @@ class MRPC(HFTask):
return True return True
def has_test_docs(self): def has_test_docs(self):
return True return False
def fewshot_description(self): def fewshot_description(self):
return "Indicate if both sentences mean the same thing." return "Indicate if both sentences mean the same thing."
...@@ -386,7 +386,7 @@ class QQP(HFTask): ...@@ -386,7 +386,7 @@ class QQP(HFTask):
return True return True
def has_test_docs(self): def has_test_docs(self):
return True return False
def fewshot_description(self): def fewshot_description(self):
return "Indicate if both questions ask the same thing." return "Indicate if both questions ask the same thing."
......
...@@ -29,4 +29,4 @@ def test_evaluator(taskname, Task): ...@@ -29,4 +29,4 @@ def test_evaluator(taskname, Task):
lm.loglikelihood = ll_fn lm.loglikelihood = ll_fn
evaluator.evaluate(lm, task_dict, False, 0, 10) evaluator.evaluate(lm, task_dict, False, 0, 3)
\ No newline at end of file \ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment