gpt2.py 3.14 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
2
import transformers
import torch
Jason Phang's avatar
Jason Phang committed
3
import torch.nn.functional as F
Jason Phang's avatar
lib  
Jason Phang committed
4
5
from lm_eval.base import LM
from lm_eval import utils
Leo Gao's avatar
Update  
Leo Gao committed
6
from tqdm import tqdm
Jason Phang's avatar
gpt3  
Jason Phang committed
7
8
9


class GPT2LM(LM):
    # Maximum number of new tokens greedy_until may generate per request.
    # Also used to decide how much of the prompt to keep: the prompt is
    # truncated so prompt + MAX_GEN_TOKS fits in GPT-2's 1024-token window.
    MAX_GEN_TOKS = 256

12
    def __init__(self, device="cpu", pretrained='gpt2'):
        """Load a pretrained GPT-2 model and tokenizer onto the given device.

        Args:
            device: torch device string, e.g. "cpu" or "cuda".
            pretrained: HuggingFace model name or path (default "gpt2").
        """
        self.device = torch.device(device)

        model = transformers.GPT2LMHeadModel.from_pretrained(pretrained)
        self.gpt2 = model.to(self.device)
        self.gpt2.eval()  # inference only; disable dropout etc.

        tokenizer = transformers.GPT2TokenizerFast.from_pretrained(pretrained)
        tokenizer.pad_token = "<|endoftext|>"
        self.tokenizer = tokenizer

        # Sanity check: the tokenizer must not merge or alter newline tokens,
        # since downstream prompt formatting depends on exact tokenization.
        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]

Jason Phang's avatar
Jason Phang committed
21
    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Build a GPT2LM from a comma-separated key=value argument string.

        Recognized keys: "device" (default "cpu") and "pretrained"
        (default "gpt2"); unknown keys are ignored.
        """
        kwargs = utils.simple_parse_args_string(arg_string)
        return cls(
            device=kwargs.get("device", "cpu"),
            pretrained=kwargs.get("pretrained", "gpt2"),
        )
Jason Phang's avatar
Jason Phang committed
25

Leo Gao's avatar
Leo Gao committed
26
    def loglikelihood(self, requests):
        """Score each (context, continuation) pair under the model.

        Args:
            requests: iterable of (context, continuation) string pairs.

        Returns:
            list of (float, bool) tuples: the summed log-probability of the
            continuation tokens given the context, and whether the
            continuation is exactly the model's greedy completion.
        """
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        with torch.no_grad():
            # TODO: vectorize properly
            # TODO: automatic batch size detection for vectorization
            for context, continuation in tqdm(requests):
                # when too long to fit in context, truncate from the left
                if context == "":
                    # end of text as context: GPT-2's <|endoftext|> id
                    context_enc = [50256]
                else:
                    context_enc = self.tokenizer.encode(context)

                continuation_enc = self.tokenizer.encode(continuation)
                # Keep only the last 1024 tokens (GPT-2's context window).
                inp = torch.tensor(
                    [(context_enc + continuation_enc)[-1024:]], dtype=torch.long
                ).to(self.device)

                # Number of context tokens that survived left-truncation.
                # Clamp to at least 1: if the continuation alone fills the
                # window, ctxlen would be <= 0, the logits slice below would be
                # empty, and the comparison with cont_toks would crash. Keeping
                # one token means we condition on it and drop the earliest
                # continuation tokens from scoring instead.
                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
                ctxlen = max(ctxlen, 1)

                cont_toks = inp[:, ctxlen:]  # [batch, seq]
                # Logits at position i predict token i+1, hence the shift.
                logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

                greedy_tokens = logits.argmax(dim=-1)
                max_equal = (greedy_tokens == cont_toks).all()

                # Pick out the log-prob assigned to each actual continuation token.
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]

                res.append((float(logits.sum()), bool(max_equal)))

        return res
    
Leo Gao's avatar
Update  
Leo Gao committed
58
    def greedy_until(self, requests):
        """Greedily decode each context until a stop sequence is produced.

        Args:
            requests: iterable of (context, until) pairs, where `until` is a
                stop string or a list of stop strings.

        Returns:
            list of generated strings, each truncated at the earliest
            occurrence of any stop sequence.
        """
        # TODO: implement fully general `until` that handles untils that are
        # multiple tokens or that span multiple tokens correctly
        results = []

        for context, until in tqdm(requests):
            if isinstance(until, str):
                until = [until]

            # Keep only the most recent prompt tokens so that prompt plus
            # MAX_GEN_TOKS generated tokens fits in the 1024-token window.
            context_enc = torch.tensor(
                [self.tokenizer.encode(context)[self.MAX_GEN_TOKS - 1024:]]
            ).to(self.device)

            # The first stop sequence doubles as EOS for generate(); the
            # single-element unpack enforces that it is exactly one token.
            primary_until, = self.tokenizer.encode(until[0])

            cont = self.gpt2.generate(
                context_enc,
                max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
                eos_token_id=primary_until,
                do_sample=False,
            )

            # Decode only the newly generated tokens, not the prompt.
            text = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])

            # Trim at the earliest occurrence of any stop sequence.
            for term in until:
                text = text.split(term)[0]

            results.append(text)

        return results