gpt2.py
import transformers
from base import LM
import torch
import torch.nn.functional as F


class GPT2LM(LM):
    def __init__(self, dev='cpu'):
        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(dev)
        self.tok = transformers.GPT2Tokenizer.from_pretrained('gpt2')
        self.dev = dev
    
    def generate(self, context, until):
        context = torch.tensor([self.tok.encode(context.strip())], dtype=torch.long).to(self.dev)
        # greedy decoding; `until` must correspond to a single entry in the GPT-2 BPE
        # vocabulary, since its token id is passed as the eos_token_id that stops generation
        res = self.gpt2.generate(context, eos_token_id=self.tok.encoder[until], do_sample=False, max_length=1024)

        # chop off the prompt and the final eos token
        return self.tok.decode(res[0][len(context[0]):-1]).strip()

    def loglikelihood(self, context, continuation):
        print('likelihood:', context, continuation)
        inp = torch.tensor([self.tok.encode(context + continuation)], dtype=torch.long).to(self.dev)
        ctxlen = len(self.tok.encode(context.strip()))

        cont_toks = inp[:, ctxlen:]  # continuation token ids, [batch, seq]
        # logits at position i predict token i + 1, so shift the window back by one
        # to align each prediction with the continuation token it scores
        logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

        # pick out the log-probability the model assigns to each continuation token
        return torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]
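

# Minimal usage sketch: assumes `base.py` defines the `LM` base class imported above,
# and that `until` is a string mapping to a single GPT-2 BPE token (e.g. '.').
if __name__ == '__main__':
    lm = GPT2LM(dev='cuda' if torch.cuda.is_available() else 'cpu')

    # greedy completion that stops at the first '.' token
    print('generate:', lm.generate('The quick brown fox', until='.'))

    # per-token log-probabilities of the continuation given the context
    logprobs = lm.loglikelihood('The quick brown fox', ' jumps over the lazy dog')
    print('loglikelihood (sum):', logprobs.sum().item())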