Commit e723d3d5 authored by Leo Gao's avatar Leo Gao
Browse files

Implement isgreedy

parent 27a859e2
...@@ -31,10 +31,15 @@ class GPT2LM(LM): ...@@ -31,10 +31,15 @@ class GPT2LM(LM):
cont_toks = inp[:, ctxlen:] # [batch, seq] cont_toks = inp[:, ctxlen:] # [batch, seq]
logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1] # [batch, seq, vocab] logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1] # [batch, seq, vocab]
greedy_tokens = logits.argmax(dim=-1)
max_equal = (greedy_tokens == cont_toks).all()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [batch, seq] logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [batch, seq]
# TODO: implement isgreedy # TODO: implement isgreedy
res.append((float(logits.sum()), False)) res.append((float(logits.sum()), bool(max_equal)))
return res return res
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment