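"""
Language model wrappers for lm_eval: TorchLM implements the shared token-level
loglikelihood and greedy-generation logic for PyTorch models, and HFLM wraps
HuggingFace `transformers` causal LMs (GPT-2 by default).
"""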
import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
from lm_eval.base import LM, TokenizedLM
from lm_eval import utils
from tqdm import tqdm
import numpy as np
from abc import ABC, abstractmethod
from typing import Iterable


class TorchLM(TokenizedLM):
    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id):
        pass

    @abstractmethod
    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        pass

    # subclass must implement properties batch_size, vocab_size, eot_token_id, max_gen_toks, max_length, device.
    # TODO: enforce this somehow
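    # (e.g. this could be enforced by declaring them as abstract properties:
    #      @property
    #      @abstractmethod
    #      def batch_size(self): ...
    #  and likewise for the others)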

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be overestimates rather than underestimates, which is more useful for planning
            # - when iterating through the list, the first element of each batch is always the longest, so its length is the padded context length for that batch.
            #   this simplifies the batching logic and, more importantly, makes automatic adaptive batching much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return (-len(toks), tuple(toks))
        
        # TODO: automatic (variable) batch size detection for vectorization
        reord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
            inps = []
            contlens = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9 <- last token is deleted by the [:-1] slice below
                # gpt2    \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the logits[inplen-contlen:inplen] slice below
                # cont_toks      4 5 6 7 8 9

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length+1):][:-1]
                , dtype=torch.long).to(self.device)
                inplen, = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = padding_length if padding_length is not None else inplen

                # pad to length
                inp = torch.cat([
                    inp, # [seq]
                    torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device) # [padding_length - seq]
                ], dim=0)

                inps.append(inp.unsqueeze(0))
                contlens.append(cont)
                inplens.append(inplen)

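            # one forward pass over the whole padded batch; log_softmax turns the logits into
            # log-probabilities, and .cpu() moves them off the GPU for the per-request slicing below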
            multi_logits = F.log_softmax(self._model_call(torch.cat(inps, dim=0)), dim=-1).cpu()  # [batch, seq, vocab]

            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(chunk, multi_logits, inps, inplens, contlens):
                contlen = len(cont_toks)

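                # since the last input token was dropped, logits[i] is the prediction for input
                # position i+1, so [inplen-contlen:inplen] picks out exactly the positions that
                # predict the continuation tokens (padding past inplen is ignored)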
                logits = logits[inplen-contlen:inplen].unsqueeze(0) # [1, seq, vocab]

                greedy_tokens = logits.argmax(dim=-1)

                # cont_toks :: [1, seq]
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)

                max_equal = (greedy_tokens == cont_toks).all()

                #last_token_slice = logits[:, -1, :].squeeze(0).tolist()

                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq]

                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return reord.get_original(res)
    
    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles `until` strings that
        # encode to multiple tokens, or that span token boundaries, correctly

        # TODO: extract to TokenizedLM?
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return (len(toks), x[0])
        
        reord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(reord.get_reordered()):
            if isinstance(until, str): until = [until]

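            # the single-element unpacking assumes the first `until` string encodes to exactly
            # one token; that token id is then used as the eos_token_id for generation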
            primary_until, = self.tok_encode(until[0])
            
            # left-truncate the context so there is room to generate up to max_gen_toks new tokens within max_length
            context_enc = torch.tensor([self.tok_encode(context)[self.max_gen_toks - self.max_length:]]).to(self.device)

            cont = self._model_generate(context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until)

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:])

            for term in until:
                s = s.split(term)[0]
            
            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)
            
            res.append(s)
        
        return reord.get_original(res)


class HFLM(TorchLM):

    def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1):
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, int)

        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained, revision=revision +("/" + subfolder if subfolder is not None else "")).to(self.device)
        self.gpt2.eval()

        # the pretrained tokenizer for neo is broken for now, so a working tokenizer (e.g. gpt2) can be passed in explicitly via `tokenizer`
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained if tokenizer is None else tokenizer, revision=revision, subfolder=subfolder)

        assert isinstance(self.tokenizer, (
            transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
            transformers.T5Tokenizer, transformers.T5TokenizerFast,
        )), "this tokenizer has not been checked for compatibility yet!"

        self.vocab_size = self.tokenizer.vocab_size
        self.eot_token_id = self.tokenizer.eos_token_id # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        self.max_gen_toks = 256

        try:
            self.max_length = self.gpt2.config.n_ctx
        except AttributeError:
            # gptneoconfig doesn't have n_ctx apparently
            self.max_length = self.gpt2.config.max_position_embeddings

        if isinstance(self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)):
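            # sanity-check that newlines are tokenized as expected (catches tokenizers that
            # normalize or strip whitespace)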
            assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373], self.tokenizer.encode('hello\n\nhello')

        # multithreading and batching
        gpus = torch.cuda.device_count()
        batch_size_per_gpu = batch_size # todo: adaptive batch size

        # TODO: fix multi-gpu
        self.batch_size = batch_size_per_gpu  # * gpus

        # TODO: fix multi-gpu
        # if gpus > 1:
        #     self.gpt2 = nn.DataParallel(self.gpt2)
    
    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)
    
    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
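            # keep only the first 50257 logits (the GPT-2 tokenizer vocab size) in case the
            # model's output layer is padded beyond the tokenizer vocabulary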
            return self.gpt2(inps)[0][:, :, :50257]
    
    def _model_generate(self, context, max_length, eos_token_id):
        return self.gpt2.generate(
            context,
            max_length=max_length,
            eos_token_id=eos_token_id,
            do_sample=False
        )


# for backwards compatibility
GPT2LM = HFLM
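# Illustrative usage sketch: shows the (cache_key, context_tokens, continuation_tokens)
# request format that _loglikelihood_tokens consumes above. The model name, device choice,
# and example strings are assumptions for demonstration; in normal use these methods are
# driven by the lm_eval evaluator rather than called directly.
if __name__ == "__main__":
    lm = HFLM(device='cuda' if torch.cuda.is_available() else 'cpu', pretrained='gpt2')

    context = "The quick brown fox"
    continuation = " jumps over the lazy dog"

    # cache_key=None skips the evaluator's partial-result cache
    request = (None, lm.tok_encode(context), lm.tok_encode(continuation))
    (loglikelihood, is_greedy), = lm._loglikelihood_tokens([request], disable_tqdm=True)
    print(loglikelihood, is_greedy)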