import transformers
import torch
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm
import numpy as np


class GPT2LM(LM):
    MAX_GEN_TOKS = 256
    VOCAB_SIZE = 50257
    EOT_TOKEN_ID = 50256

    def __init__(self, device=None, pretrained='gpt2'):
        super().__init__()
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device)
        self.gpt2.eval()

        # pretrained tokenizer for neo is broken for now so just hardcoding this to gpt2
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
        self.tokenizer.pad_token = "<|endoftext|>"
        try:
            self.max_length = self.gpt2.config.n_ctx
        except AttributeError:
            # gptneoconfig doesn't have n_ctx apparently
            self.max_length = self.gpt2.config.max_position_embeddings

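        # sanity check: GPT-2 BPE should encode 'hello\n\nhello' to exactly these ids (no BOS token, each newline kept as its own token)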
        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]

    @classmethod
    def create_from_arg_string(cls, arg_string):
        args = utils.simple_parse_args_string(arg_string)
        return cls(device=args.get("device", None), pretrained=args.get("pretrained", "gpt2"))

    def loglikelihood(self, requests):
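        # each request is a (context, continuation) string pair; tokenize both halves up front for _loglikelihood_tokens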
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc = [self.EOT_TOKEN_ID]
            else:
                context_enc = self.tokenizer.encode(context)

            continuation_enc = self.tokenizer.encode(continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_perplexity(self, requests):
        # TODO: Implement caching once we've confirmed the perplexity implementation
        # TODO: automatic batch size detection for vectorization

        loglikelihoods = []
        with torch.no_grad():
            for string, in tqdm(requests):
                encoded = self.tokenizer.encode_plus(string)["input_ids"]
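                # score the string with rolling (context, continuation) windows capped at max_length; make_disjoint_window keeps the scored tokens disjoint so nothing is counted twice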
                rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
                    token_list=encoded,
                    prefix_token=self.EOT_TOKEN_ID,
                    max_seq_len=self.max_length,
                    context_len=1,
                )))

                # todo: figure out partial caching
                rolling_token_windows = [(None,) + x for x in rolling_token_windows]

                string_nll = self._loglikelihood_tokens(rolling_token_windows)
                
                # discard is_greedy
                string_nll = [x[0] for x in string_nll]
                
                string_nll = sum(string_nll)
                loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []
        with torch.no_grad():
            # TODO: vectorize properly
            # TODO: automatic batch size detection for vectorization

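            # Reorderer sort key: total token count (with the tokens themselves as tie-breaker) so similarly sized requests end up adjacent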
            def _collate(x):
                toks = x[1] + x[2]
                return (len(toks), tuple(toks))

            reord = utils.Reorderer(requests, _collate)
            for cache_key, context_enc, continuation_enc in tqdm(reord.get_reordered()):
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length
                
                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
                # gpt2    \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the [:, -len(continuation_enc):, :self.VOCAB_SIZE] slice
                # cont_toks      4 5 6 7 8 9

                # when too long to fit in context, truncate from the left
                inp = torch.tensor([
                    (context_enc + continuation_enc)[-(self.max_length+1):] 
                ], dtype=torch.long).to(self.device)

                cont_toks = inp[:, -len(continuation_enc):]  # [batch, seq]

                logits = F.log_softmax(self.gpt2(inp[:, :-1])[0][:, -len(continuation_enc):, :self.VOCAB_SIZE], dim=-1)  # [batch, seq, vocab] - logits are clipped to VOCAB_SIZE to drop any extra padding entries beyond the real vocabulary

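                # is_greedy check: would greedy decoding reproduce the continuation exactly?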
                greedy_tokens = logits.argmax(dim=-1)
                max_equal = (greedy_tokens == cont_toks).all()

                last_token_slice = logits[:, -1, :].squeeze(0).tolist()

                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [batch, seq]

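                # answer = (sum of log-probs assigned to the continuation tokens, whether the continuation matches the greedy decode)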
                answer = (float(logits.cpu().to(torch.float64).sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return reord.get_original(res)
    
    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles untils that are 
        # multiple tokens or that span multiple tokens correctly
        res = []

        def _collate(x):
            toks = self.tokenizer.encode(x[0])
            return (len(toks), x[0])
        
        reord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(reord.get_reordered()):
            if isinstance(until, str): until = [until]

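            # left-truncate the context to its last (max_length - MAX_GEN_TOKS) tokens so generation has room for MAX_GEN_TOKS new ones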
            context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - self.max_length:]]).to(self.device)

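            # the first (primary) stop sequence must tokenize to a single token; it doubles as the eos id for generation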
            primary_until, = self.tokenizer.encode(until[0])

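            # greedy decoding (do_sample=False), stopping early once the primary stop token is produced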
            cont = self.gpt2.generate(
                context_enc,
                max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
                eos_token_id=primary_until,
                do_sample=False
            )

            s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])

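            # cut the decoded continuation at the first occurrence of any stop sequence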
            for term in until:
                s = s.split(term)[0]
            
            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)
            
            res.append(s)
        
        return reord.get_original(res)