gpt2.py 9.24 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
2
import transformers
import torch
3
import torch.nn as nn
Jason Phang's avatar
Jason Phang committed
4
import torch.nn.functional as F
Jason Phang's avatar
lib  
Jason Phang committed
5
6
from lm_eval.base import LM
from lm_eval import utils
Leo Gao's avatar
Update  
Leo Gao committed
7
from tqdm import tqdm
Jason Phang's avatar
Jason Phang committed
8
import numpy as np
Jason Phang's avatar
gpt3  
Jason Phang committed
9
10
11


class GPT2LM(LM):
    """LM adapter that scores and generates text with a Hugging Face causal LM.

    Any checkpoint loadable by ``transformers.AutoModelForCausalLM`` can be used
    via ``pretrained``. The tokenizer, however, is always the stock GPT-2
    tokenizer (see ``__init__``).
    """

    # Maximum number of new tokens generated per request in greedy_until.
    MAX_GEN_TOKS = 256
    # GPT-2 vocabulary size; _model_call drops any model output columns past it.
    VOCAB_SIZE = 50257
    # Token id used as the end-of-text token (also used as a BOS stand-in).
    EOT_TOKEN_ID = 50256

    def __init__(self, device='cuda', pretrained='gpt2', batch_size=1):
        """Load the model and tokenizer.

        Args:
            device: torch device string. A falsy value (``None``, ``''``)
                selects CUDA when available, otherwise CPU.
            pretrained: model name or path passed to
                ``transformers.AutoModelForCausalLM.from_pretrained``.
            batch_size: number of requests packed into one forward pass in
                ``_loglikelihood_tokens``.
        """
        super().__init__()

        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device)
        self.gpt2.eval()

        # pretrained tokenizer for neo is broken for now so just hardcoding this to gpt2
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
        self.tokenizer.pad_token = "<|endoftext|>"

        try:
            self.max_length = self.gpt2.config.n_ctx
        except AttributeError:
            # GPTNeoConfig has no n_ctx attribute; fall back to max_position_embeddings
            self.max_length = self.gpt2.config.max_position_embeddings

        # sanity check that the hardcoded tokenizer is the expected GPT-2 BPE
        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]

        # TODO: adaptive batch size
        # TODO: fix multi-gpu (e.g. wrap self.gpt2 in nn.DataParallel).
        # Until multi-GPU is supported, every batch runs on the single device above.
        # The batch size must therefore NOT be scaled by torch.cuda.device_count().
        # Scaling it made the batch size 0 on CPU-only machines and N times the
        # per-device size on N GPUs.
        self.batch_size = batch_size

    @classmethod
    def create_from_arg_string(cls, arg_string, **kwargs):
        """Construct an instance from an argument string.

        ``arg_string`` is parsed with ``utils.simple_parse_args_string``; only its
        ``pretrained`` entry is used (default ``'gpt2'``). Keyword arguments whose
        value is ``None`` are dropped so the constructor defaults apply; the rest
        are forwarded to ``__init__``.
        """
        args = utils.simple_parse_args_string(arg_string)
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return cls(pretrained=args.get("pretrained", "gpt2"), **kwargs)

    def loglikelihood(self, requests):
        """Score each ``(context, continuation)`` string pair.

        An empty context is replaced by the end-of-text token so the model
        always has at least one token to condition on.

        Returns:
            One ``(log-likelihood of continuation, is_greedy)`` tuple per
            request, in the order given by ``_loglikelihood_tokens``.
        """
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc = [self.EOT_TOKEN_ID]
            else:
                context_enc = self.tokenizer.encode(context)

            continuation_enc = self.tokenizer.encode(continuation)

            # the original string pair doubles as the partial-cache key
            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        """Return the total log-likelihood of each whole string.

        Each string is tokenized and split into windows by
        ``utils.get_rolling_token_windows`` / ``utils.make_disjoint_window`` (with
        the EOT token as prefix and windows bounded by ``max_length``). The
        per-window continuation log-likelihoods are then summed.

        Args:
            requests: iterable of 1-tuples ``(string,)``.

        Returns:
            List of floats, one summed log-likelihood per string.
        """
        # TODO: Implement caching once we've confirmed the perplexity implementation
        # TODO: automatic batch size detection for vectorization

        loglikelihoods = []
        with torch.no_grad():
            for string, in tqdm(requests):
                rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
                    token_list=self.tokenizer.encode(string),
                    prefix_token=self.EOT_TOKEN_ID,
                    max_seq_len=self.max_length,
                    context_len=1,
                )))

                # a None cache key makes _loglikelihood_tokens skip partial caching
                rolling_token_windows = [(None,) + x for x in rolling_token_windows]

                # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for that
                window_results = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True)

                # discard is_greedy and sum the per-window log-likelihoods
                string_ll = sum(x[0] for x in window_results)
                loglikelihoods.append(string_ll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        """Compute continuation log-likelihoods for pre-tokenized requests.

        Args:
            requests: sequence of ``(cache_key, context_enc, continuation_enc)``.
                Both token lists must be non-empty. ``cache_key`` may be ``None``
                to skip partial caching.
            disable_tqdm: suppress the progress bar.

        Returns:
            List of ``(sum of continuation token log-probs, is_greedy)`` tuples,
            mapped back through ``utils.Reorderer.get_original``. ``is_greedy``
            is True when every continuation token is the argmax prediction.
        """
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []
        with torch.no_grad():

            def _collate(x):
                # the negative sign on len(toks) sorts descending - this has a few advantages:
                # - time estimates will always be over not underestimates, which is more useful for planning
                # - to know the size of a batch when going through the list, you know the first one is always the batch padded context length.
                #   this is useful to simplify the batching logic and more importantly to make automatic adaptive batches much much easier to implement
                # - any OOMs will happen right away rather than near the end
                toks = x[1] + x[2]
                return (-len(toks), tuple(toks))

            # TODO: automatic (variable) batch size detection for vectorization
            reord = utils.Reorderer(requests, _collate)
            for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
                inps = []
                cont_toks_list = []
                inplens = []

                padding_length = None

                # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
                # tensors, then we pack them together into a batch, call the model, and then pick it all apart
                # again because vectorizing is annoying
                for _, context_enc, continuation_enc in chunk:
                    # sanity check
                    assert len(context_enc) > 0
                    assert len(continuation_enc) > 0
                    assert len(continuation_enc) <= self.max_length

                    # how this all works:
                    #          CTX      CONT
                    # inp    0 1 2 3|4 5 6 7 8 9   <- last token is dropped by the [:-1] slice
                    # gpt2    \               \
                    # logits   1 2 3|4 5 6 7 8 9   <- the ctx half is discarded by logits[inplen - contlen:inplen]
                    # cont_toks      4 5 6 7 8 9

                    # when too long to fit in context, truncate from the left
                    inp = torch.tensor(
                        (context_enc + continuation_enc)[-(self.max_length + 1):][:-1],
                        dtype=torch.long,
                        device=self.device,
                    )
                    inplen, = inp.shape

                    # since in _collate we make sure length is descending, the longest is always the first one.
                    if padding_length is None:
                        padding_length = inplen

                    # right-pad with token 0; padded positions lie past inplen and are never read back
                    inp = torch.cat([
                        inp,  # [seq]
                        torch.zeros(padding_length - inplen, dtype=torch.long, device=inp.device),  # [padding_length - seq]
                    ], dim=0)

                    inps.append(inp.unsqueeze(0))
                    cont_toks_list.append(continuation_enc)
                    inplens.append(inplen)

                multi_logits = F.log_softmax(self._model_call(torch.cat(inps, dim=0)), dim=-1).cpu()  # [batch, seq, vocab]

                for (cache_key, _, _), logits, inplen, cont_toks in zip(chunk, multi_logits, inplens, cont_toks_list):
                    contlen = len(cont_toks)

                    # keep only the positions that predict continuation tokens
                    logits = logits[inplen - contlen:inplen].unsqueeze(0)  # [1, seq, vocab]

                    greedy_tokens = logits.argmax(dim=-1)

                    # cont_toks :: [1, seq]
                    cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)

                    max_equal = (greedy_tokens == cont_toks).all()

                    # log-prob assigned to each actual continuation token
                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]

                    answer = (float(logits.sum()), bool(max_equal))

                    # partial caching
                    if cache_key is not None:
                        self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                    res.append(answer)

        return reord.get_original(res)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model, truncated to the first VOCAB_SIZE entries
        """
        return self.gpt2(inps)[0][:, :, :self.VOCAB_SIZE]

    def greedy_until(self, requests):
        """Greedily generate a continuation for each ``(context, until)`` request.

        Args:
            requests: iterable of ``(context, until)``. ``until`` is a stop string
                or a list of stop strings.

        Returns:
            Generated strings, each cut at the first occurrence of every stop
            string, mapped back through ``utils.Reorderer.get_original``.
        """
        # TODO: implement fully general `until` that handles untils that are
        # multiple tokens or that span multiple tokens correctly
        res = []

        def _collate(x):
            toks = self.tokenizer.encode(x[0])
            return (len(toks), x[0])

        reord = utils.Reorderer(requests, _collate)

        for context, request_until in tqdm(reord.get_reordered()):
            until = [request_until] if isinstance(request_until, str) else request_until

            # keep only the most recent context tokens so context + generation fits in max_length
            context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - self.max_length:]]).to(self.device)

            # generate() stops on a single token id. If the first stop string is
            # exactly one token, stop on it. Otherwise (this used to raise
            # ValueError), stop on EOT / MAX_GEN_TOKS and rely on the string split
            # below to truncate.
            until_toks = self.tokenizer.encode(until[0])
            primary_until = until_toks[0] if len(until_toks) == 1 else self.EOT_TOKEN_ID

            cont = self.gpt2.generate(
                context_enc,
                max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
                eos_token_id=primary_until,
                do_sample=False
            )

            # decode only the newly generated tokens
            s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])

            for term in until:
                s = s.split(term)[0]

            # partial caching: key on the request as received (``until`` may have
            # been a bare string) so it lines up with a key derived from the
            # original request
            self.cache_hook.add_partial("greedy_until", (context, request_until), s)

            res.append(s)

        return reord.get_original(res)