gpt2.py
import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm
import numpy as np


class GPT2LM(LM):
    MAX_GEN_TOKS = 256

    def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1):
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, int)

        if device:            
            if device not in ["cuda", "cpu"]:
                device = int(device)
            self.device = torch.device(device)
            print(f"Using device '{device}'")
        else:
            print("Device not specificed")
            print(f"Cuda Available? {torch.cuda.is_available()}")
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained,
            revision=revision + ("/" + subfolder if subfolder is not None else ""),
        ).to(self.device)
        self.gpt2.eval()

        # the pretrained tokenizer for neo is broken for now, so allow overriding it via the `tokenizer` argument
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
            subfolder=subfolder,
        )

        assert isinstance(self.tokenizer, (
            transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
            transformers.T5Tokenizer, transformers.T5TokenizerFast,
        )), "this tokenizer has not been checked for compatibility yet!"

        self.VOCAB_SIZE = self.tokenizer.vocab_size
        self.EOT_TOKEN_ID = self.tokenizer.eos_token_id
        print(self.EOT_TOKEN_ID)

        try:
            self.max_length = self.gpt2.config.n_ctx
        except AttributeError:
            # GPTNeoConfig doesn't have n_ctx, apparently
            self.max_length = self.gpt2.config.max_position_embeddings

        if isinstance(self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)): 
            assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
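            # (in the reference GPT-2 byte-level BPE, 31373 is "hello" and 198 is "\n";
            # this guards against a tokenizer that merges or strips the blank line)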

        # multithreading and batching
        gpus = torch.cuda.device_count()
        batch_size_per_gpu = batch_size # todo: adaptive batch size

        # TODO: fix multi-gpu
        self.batch_size = batch_size_per_gpu  # * gpus

        # TODO: fix multi-gpu
        # if gpus > 1:
        #     self.gpt2 = nn.DataParallel(self.gpt2)

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config={}):
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)
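        # Illustrative example (model name arbitrary): create_from_arg_string("pretrained=gpt2-medium,device=cuda")
        # parses the comma-separated key=value pairs into kwargs for __init__; non-None entries of
        # additional_config are passed alongside them.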

    def loglikelihood(self, requests):
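        """
        requests: a list of (context, continuation) string pairs

        returns: a list of (log-likelihood, is_greedy) tuples, one per request, where
        is_greedy is True iff the continuation is exactly the model's greedy completion
        of the context
        """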
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc = [self.EOT_TOKEN_ID]
            else:
                context_enc = self.tokenizer.encode(context, add_special_tokens=False)

            continuation_enc = self.tokenizer.encode(continuation, add_special_tokens=False)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
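        """
        requests: a list of (string,) 1-tuples

        returns: a list with the total log-likelihood of each string, computed over
        rolling windows of at most self.max_length tokens so that every token is
        scored exactly once
        """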
        # TODO: Implement caching once we've confirmed the perplexity implementation
        # TODO: automatic batch size detection for vectorization

        loglikelihoods = []
        with torch.no_grad():
            for string, in tqdm(requests):
                rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
                    token_list=self.tokenizer.encode(string, add_special_tokens=False),
                    prefix_token=self.EOT_TOKEN_ID,
                    max_seq_len=self.max_length,
                    context_len=1,
                )))

                rolling_token_windows = [(None,) + x for x in rolling_token_windows]
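                # each window is a (context_tokens, continuation_tokens) pair; prepending None
                # as the cache_key makes _loglikelihood_tokens skip partial caching for these
                # per-window sub-requests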

                # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for that
                string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True)
                
                # discard is_greedy
                string_nll = [x[0] for x in string_nll]
                
                string_nll = sum(string_nll)
                loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
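        """
        requests: a list of (cache_key, context_enc, continuation_enc) triples, where
        cache_key is the original (context, continuation) string pair (or None to skip
        partial caching) and the *_enc entries are lists of token ids

        returns: a list of (log-likelihood, is_greedy) tuples in the original request order
        """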
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []
        with torch.no_grad():

            def _collate(x):
                # the negative sign on len(toks) sorts descending - this has a few advantages:
                # - time estimates will always be overestimates rather than underestimates, which is more useful for planning
                # - when going through the list, the first element of a batch is always the longest, i.e. the batch-padded context length.
                #   this simplifies the batching logic and, more importantly, makes automatic adaptive batching much easier to implement
                # - any OOMs will happen right away rather than near the end

                toks = x[1] + x[2]
                return (-len(toks), tuple(toks))
            
            # TODO: automatic (variable) batch size detection for vectorization
            reord = utils.Reorderer(requests, _collate)
            for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
                inps = []
                contlens = []
                inplens = []

                padding_length = None

                # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
                # tensors, then we pack them together into a batch, call the model, and then pick it all apart
                # again because vectorizing is annoying

                for _, context_enc, continuation_enc in chunk:
                    # sanity check
                    assert len(context_enc) > 0
                    assert len(continuation_enc) > 0
                    assert len(continuation_enc) <= self.max_length

                    # how this all works:
                    #          CTX      CONT
                    # inp    0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
                    # gpt2    \               \
                    # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the [:, -len(continuation_enc):, :self.VOCAB_SIZE] slice
                    # cont_toks      4 5 6 7 8 9

                    # when too long to fit in context, truncate from the left
                    inp = torch.tensor(
                        (context_enc + continuation_enc)[-(self.max_length + 1):][:-1],
                        dtype=torch.long,
                    ).to(self.device)
                    inplen, = inp.shape

                    cont = continuation_enc

                    # since in _collate we make sure length is descending, the longest is always the first one.
                    padding_length = padding_length if padding_length is not None else inplen

                    # pad to length
                    inp = torch.cat([
                        inp, # [seq]
                        torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device) # [padding_length - seq]
                    ], dim=0)
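                    # right-padding with zeros is safe: the model is causal, so padded positions
                    # cannot affect earlier logits, and only logits[inplen - contlen : inplen]
                    # are read back out below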

                    inps.append(inp.unsqueeze(0))
                    contlens.append(cont)
                    inplens.append(inplen)

                multi_logits = F.log_softmax(self._model_call(torch.cat(inps, dim=0)), dim=-1).cpu()  # [batch, seq, vocab]

                for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(chunk, multi_logits, inps, inplens, contlens):
                    contlen = len(cont_toks)

                    logits = logits[inplen-contlen:inplen].unsqueeze(0) # [1, seq, vocab]

                    greedy_tokens = logits.argmax(dim=-1)

                    # cont_toks :: [1, seq]
                    cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)

                    max_equal = (greedy_tokens == cont_toks).all()

                    #last_token_slice = logits[:, -1, :].squeeze(0).tolist()

                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq]

                    answer = (float(logits.sum()), bool(max_equal))

                    # partial caching
                    if cache_key is not None:
                        self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                    res.append(answer)

        return reord.get_original(res)
    
    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        return self.gpt2(inps)[0][:, :, :50257]
    
    def greedy_until(self, requests):
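        """
        requests: a list of (context, until) pairs, where until is a stop string or a
        list of stop strings

        returns: a list with the greedy continuation of each context, truncated at the
        first occurrence of any of its stop strings
        """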
        # TODO: implement a fully general `until` that correctly handles stop sequences
        # consisting of multiple tokens or spanning token boundaries
        res = []

        def _collate(x):
            toks = self.tokenizer.encode(x[0], add_special_tokens=False)
            return (len(toks), x[0])
        
        reord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(reord.get_reordered()):
            if isinstance(until, str): until = [until]
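            # keep only the last (self.max_length - self.MAX_GEN_TOKS) context tokens so
            # there is room to generate up to MAX_GEN_TOKS new tokens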

            context_enc = torch.tensor(
                [self.tokenizer.encode(context, add_special_tokens=False)[self.MAX_GEN_TOKS - self.max_length:]]
            ).to(self.device)

            primary_until, = self.tokenizer.encode(until[0], add_special_tokens=False)
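            # the single-element unpacking asserts that the primary stop string encodes to
            # exactly one token, which generate() then uses as eos_token_id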

            cont = self.gpt2.generate(
                context_enc,
                max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
                eos_token_id=primary_until,
                do_sample=False
            )

            s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])

            for term in until:
                s = s.split(term)[0]
            
            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)
            
            res.append(s)
        
        return reord.get_original(res)
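

# A minimal usage sketch (illustrative only; the model name and prompts below are
# arbitrary examples, not part of this file):
#
#   lm = GPT2LM(device="cpu", pretrained="gpt2", batch_size=1)
#   print(lm.loglikelihood([("The capital of France is", " Paris")]))
#   print(lm.greedy_until([("Q: What is two plus two?\nA:", ["\n"])]))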