"lm_eval/tasks/ai2d/utils.py" did not exist on "bc5c554d5169014b3e6f8240a20b1497836f7bde"
gpt2.py
import transformers
from base import LM
import torch
import torch.nn.functional as F


class GPT2LM(LM):
    def __init__(self):
        # load the pretrained GPT-2 (small) weights and the matching BPE tokenizer
        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
        self.tok = transformers.GPT2Tokenizer.from_pretrained('gpt2')

    def generate(self, context, until):
        # greedy decoding; `until` must be a string that appears verbatim as a single
        # token in the GPT-2 vocab, since it is looked up directly to get the eos id
        context = torch.tensor([self.tok.encode(context.strip())], dtype=torch.long)
        res = self.gpt2.generate(context, eos_token_id=self.tok.encoder[until], do_sample=False, max_length=1024)

        # chop off the prompt and the final eos token
        return self.tok.decode(res[0][len(context[0]):-1]).strip()

    def loglikelihood(self, context, continuation):
        print('likelihood:', context, continuation)  # debug trace

        # encode context + continuation in a single pass so BPE merges are consistent;
        # ctxlen (computed without the extra strip, which could shift the token
        # boundary) marks where the continuation tokens begin
        inp = torch.tensor([self.tok.encode(context + continuation)], dtype=torch.long)
        ctxlen = len(self.tok.encode(context))

        cont_toks = inp[:, ctxlen:]  # [batch, seq]
        with torch.no_grad():
            # logits at position i predict token i + 1, hence the ctxlen - 1 offset
            logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

        # log-probability the model assigns to each continuation token
        return torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
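

# A minimal usage sketch (not part of the original file); the prompts and the '.'
# stop token below are illustrative assumptions, and `until` must be a string that
# exists as a single token in the GPT-2 vocab for generate() to work.
if __name__ == '__main__':
    lm = GPT2LM()

    # greedy completion that stops at the first '.' token
    print(lm.generate('My favorite food is', until='.'))

    # per-token log-probabilities of the continuation given the context
    print(lm.loglikelihood('The capital of France is', ' Paris'))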