gpt2.py
import transformers
import torch
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils


class GPT2LM(LM):
    def __init__(self, device="cpu"):
        self.device = torch.device(device)
        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
        self.gpt2.eval()
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

    @classmethod
    def create_from_arg_string(cls, arg_string):
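        # arg_string is assumed to be a comma-separated list of key=value pairs
        # (e.g. "device=cuda:0") that utils.simple_parse_args_string turns into a dict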
        args = utils.simple_parse_args_string(arg_string)
        return cls(device=args.get("device", "cpu"))

    def loglikelihood(self, context, continuation, truncate=True):
        # when too long to fit in context, truncate from the left
        context_enc = self.tokenizer.encode(context)
        continuation_enc = self.tokenizer.encode(continuation)
        inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
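        # number of context tokens that survive the left-truncation; only the context
        # is trimmed, the continuation is kept intact (assuming it fits in the window)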
        ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)

        cont_toks = inp[:, ctxlen:]  # [batch, seq]
        with torch.no_grad():
            # log-probabilities over the vocabulary, keeping only the positions that
            # predict the continuation tokens (predictions are shifted left by one)
            logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

        # log-probability the model assigns to each actual continuation token
        return torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
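

# Minimal usage sketch, assuming the class is used directly (the harness would
# normally construct it via create_from_arg_string); the sequence score is the
# sum of the per-token log-probabilities returned by loglikelihood:
#
#   lm = GPT2LM(device="cuda" if torch.cuda.is_available() else "cpu")
#   token_logprobs = lm.loglikelihood("The capital of France is", " Paris")
#   print(token_logprobs.sum().item())  # higher (less negative) means more likely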