gpt2.py 2.67 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
2
import transformers
import torch
Jason Phang's avatar
Jason Phang committed
3
import torch.nn.functional as F
Jason Phang's avatar
lib  
Jason Phang committed
4
5
from lm_eval.base import LM
from lm_eval import utils
Leo Gao's avatar
Update  
Leo Gao committed
6
from tqdm import tqdm
Jason Phang's avatar
gpt3  
Jason Phang committed
7
8
9


class GPT2LM(LM):
    # Maximum context window of the GPT-2 model family (was the bare
    # literal 1024 below).
    MAX_LENGTH = 1024
    # Token id of GPT-2's "<|endoftext|>" token (was the bare literal 50256).
    EOT_TOKEN_ID = 50256

    def __init__(self, device="cpu", pretrained="gpt2"):
        """Wrap a HuggingFace GPT-2 causal LM for evaluation.

        Args:
            device: torch device string the model is placed on (default "cpu").
            pretrained: HuggingFace model name or path (default "gpt2").
                Backward-compatible generalization of the previously
                hard-coded 'gpt2'.
        """
        self.device = torch.device(device)
        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained(pretrained).to(self.device)
        # Inference only: disable dropout / put modules in eval mode.
        self.gpt2.eval()
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained(pretrained)
        # GPT-2 ships with no pad token; reuse end-of-text as the pad token.
        self.tokenizer.pad_token = "<|endoftext|>"

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Build an instance from a comma-separated "key=value" string.

        Recognized keys: ``device`` (default "cpu") and ``pretrained``
        (default "gpt2").
        """
        args = utils.simple_parse_args_string(arg_string)
        return cls(
            device=args.get("device", "cpu"),
            pretrained=args.get("pretrained", "gpt2"),
        )

    def loglikelihood(self, requests):
        """Score (context, continuation) pairs under the model.

        Args:
            requests: iterable of ``(context, continuation)`` string pairs.

        Returns:
            List of ``(logprob, is_greedy)`` tuples: ``logprob`` is the sum
            of per-token log-probabilities of the continuation given the
            (possibly left-truncated) context; ``is_greedy`` is True iff
            the continuation is exactly what greedy decoding would emit.
        """
        res = []
        # TODO: vectorize properly
        # Pure inference: no_grad avoids building autograd graphs for every
        # forward pass (previously missing — wasted memory and compute).
        with torch.no_grad():
            for context, continuation in tqdm(requests):
                # when too long to fit in context, truncate from the left
                if context == "":
                    # Empty context: condition on end-of-text instead.
                    context_enc = [self.EOT_TOKEN_ID]
                else:
                    context_enc = self.tokenizer.encode(context)

                continuation_enc = self.tokenizer.encode(continuation)
                # Keep only the last MAX_LENGTH tokens of context+continuation.
                inp = torch.tensor(
                    [(context_enc + continuation_enc)[-self.MAX_LENGTH:]],
                    dtype=torch.long,
                ).to(self.device)
                # Effective context length after any left-truncation above.
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - self.MAX_LENGTH
                )

                cont_toks = inp[:, ctxlen:]  # [batch, seq]
                # Shift by one: the logits at position i-1 predict token i,
                # hence the [ctxlen - 1:-1] slice.
                logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

                # is_greedy: does argmax decoding reproduce the continuation?
                greedy_tokens = logits.argmax(dim=-1)
                max_equal = (greedy_tokens == cont_toks).all()

                # Pick out the log-prob assigned to each actual continuation token.
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]

                res.append((float(logits.sum()), bool(max_equal)))

        return res

    def greedy_until(self, requests):
        """Greedy-decode continuations, stopping at the given string(s).

        Args:
            requests: iterable of ``(context, until)`` pairs where ``until``
                is a stop string or a list of stop strings.

        Returns:
            List of generated strings, each truncated at the first
            occurrence of any stop string and excluding the prompt.
        """
        # TODO: implement fully general `until` that handles untils that are
        # multiple tokens or that span multiple tokens correctly
        res = []

        for context, until in tqdm(requests):
            if isinstance(until, str):
                until = [until]

            context_enc = torch.tensor([self.tokenizer.encode(context)]).to(self.device)

            # NOTE(review): assumes until[0] encodes to exactly ONE token —
            # the unpacking raises ValueError otherwise (see TODO above).
            primary_until, = self.tokenizer.encode(until[0])

            # NOTE(review): self.MAX_GEN_TOKS is not defined anywhere in this
            # file — presumably inherited from the LM base class; confirm.
            cont = self.gpt2.generate(
                context_enc,
                max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
                eos_token_id=primary_until,
                do_sample=False,
            )

            # Strip the prompt: keep only the newly generated token ids.
            s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])

            # Truncate at the first occurrence of any stop string.
            for term in until:
                s = s.split(term)[0]

            res.append(s)

        return res