Commit 8fffd927 authored by Leo Gao's avatar Leo Gao
Browse files

Implement gpt2 loglikelihood

parent 31696910
import transformers import transformers
from base import LM from base import LM
import torch import torch
import torch.nn.functional as F
class GPT2LM(LM): class GPT2LM(LM):
...@@ -16,4 +17,11 @@ class GPT2LM(LM): ...@@ -16,4 +17,11 @@ class GPT2LM(LM):
return self.tok.decode(res[0][len(context[0]):-1]).strip() return self.tok.decode(res[0][len(context[0]):-1]).strip()
def loglikelihood(self, context, continuation):
    """Return per-token log-probabilities of ``continuation`` given ``context``.

    ``context + continuation`` is encoded with ``self.tok`` and run through
    ``self.gpt2`` in a single forward pass; the log-softmax output at each
    position is then read off at the token that actually follows it.

    Args:
        context: Prompt text. Must encode to at least one token after
            stripping surrounding whitespace, because the first scored
            token needs a preceding position to be predicted from.
        continuation: Text whose tokens are scored.

    Returns:
        A tensor of shape ``[1, n_continuation_tokens]`` holding the
        log-probability of each continuation token (one row, since a
        single sequence is scored per call).

    Raises:
        ValueError: If ``context`` is empty or whitespace-only.
    """
    # The context length is measured on the *stripped* context while the
    # model input uses the unstripped concatenation, so whitespace at the
    # context/continuation boundary is scored as part of the continuation.
    # NOTE(review): this assumes stripping does not otherwise change how the
    # shared prefix tokenizes -- confirm for the tokenizer in use.
    ctxlen = len(self.tok.encode(context.strip()))
    if ctxlen == 0:
        # With no context token, the slice below would be [-1:-1] (empty) and
        # torch.gather would fail with an opaque shape mismatch.
        raise ValueError(
            "loglikelihood requires a context that encodes to at least one token"
        )

    inp = torch.tensor(
        [self.tok.encode(context + continuation)], dtype=torch.long
    )
    cont_toks = inp[:, ctxlen:]  # [batch, seq]

    # Scoring only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        # assumes self.gpt2(inp)[0] is the logits tensor [batch, seq, vocab]
        logprobs = F.log_softmax(self.gpt2(inp)[0], dim=-1)

    # The output at position t predicts token t + 1, so the predictions for
    # the continuation tokens live at positions ctxlen - 1 .. len - 2.
    cont_logprobs = logprobs[:, ctxlen - 1:-1]  # [batch, seq, vocab]
    return torch.gather(cont_logprobs, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment