# gpt2.py
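"""GPT-2 adapter for the lm_eval harness.

Wraps Hugging Face's GPT2LMHeadModel behind the harness's LM interface,
providing greedy generation and per-token log-likelihood scoring.
"""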
import transformers
import torch
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
from . import MODEL_REGISTRY


@MODEL_REGISTRY.register("gpt2")
class GPT2LM(LM):
    def __init__(self, device="cpu"):
        self.device = torch.device(device)
        # from_pretrained already returns the model in eval mode; just move it to the device
        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')

    @classmethod
    def create_from_arg_string(cls, arg_string):
        # arg_string is a comma-separated list of key=value pairs, e.g. "device=cuda:0"
        args = utils.simple_parse_args_string(arg_string)
        return cls(device=args.get("device", "cpu"))

    def generate(self, context, max_gen_length):
        # encode the stripped prompt as a [1, seq] batch of token ids
        context = torch.tensor([self.tokenizer.encode(context.strip())], dtype=torch.long).to(self.device)
        # greedy decoding; note that max_length bounds prompt + generated tokens combined
        res = self.gpt2.generate(
            context,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.eos_token_id,  # GPT-2 has no pad token
            do_sample=False,
            max_length=max_gen_length,
        )

        # chop off the prompt and the final eos token; this assumes generation
        # stopped at eos rather than at max_length, otherwise it drops a real token
        return self.tokenizer.decode(res[0][len(context[0]):-1]).strip()

    def loglikelihood(self, context, continuation):
        inp = torch.tensor([self.tokenizer.encode(context + continuation)], dtype=torch.long).to(self.device)
        # token length of the context prefix; encode the same string used to build inp
        # so the split point lines up (BPE may still merge across the boundary)
        ctxlen = len(self.tokenizer.encode(context))

        cont_toks = inp[:, ctxlen:]  # [batch, seq]
        # logits at position i predict token i + 1, hence the one-step shift;
        # log_softmax turns them into log-probabilities over the vocab
        with torch.no_grad():
            logprobs = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

        # per-token log-probability of each continuation token
        return torch.gather(logprobs, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
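

# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal example of driving this adapter directly. Because of the relative
# import above, run it as a module (python -m <package>.gpt2) rather than as a
# standalone script; the prompts below are made-up placeholders.
if __name__ == "__main__":
    lm = GPT2LM.create_from_arg_string("device=cpu")
    print(lm.generate("The quick brown fox", max_gen_length=20))
    # per-token log-probabilities of " jumps" given the context
    print(lm.loglikelihood("The quick brown fox", " jumps"))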