Commit 8fffd927 authored by Leo Gao's avatar Leo Gao
Browse files

Implement gpt2 loglikelihood

parent 31696910
import transformers
from base import LM
import torch
import torch.nn.functional as F
class GPT2LM(LM):
......@@ -16,4 +17,11 @@ class GPT2LM(LM):
return self.tok.decode(res[0][len(context[0]):-1]).strip()
def loglikelihood(self, context, continuation):
pass
print('likelihood:', context, continuation)
inp = torch.tensor([self.tok.encode(context + continuation)], dtype=torch.long)
ctxlen = len(self.tok.encode(context.strip()))
cont_toks = inp[:, ctxlen:] # [batch, seq]
logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1] # [batch, seq, vocab]
return torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment