gpt2.py 11.4 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
import torch
Xingjian Shi's avatar
Xingjian Shi committed
2
import transformers
Jason Phang's avatar
gpt3  
Jason Phang committed
3

4
from tqdm import tqdm
Jason Phang's avatar
gpt3  
Jason Phang committed
5

6
7
8
import torch.nn.functional as F

from lm_eval import utils
9
from lm_eval.api.model import LM, register_model
10

11
12
from accelerate import Accelerator
from itertools import islice
13

14
@register_model("hf-causal", "gpt2")
15
class HFLM(LM):
Fabrizio Milo's avatar
Fabrizio Milo committed
16
17
18
19
20
    def __init__(
        self,
        device="cuda",
        pretrained="gpt2",
        revision="main",
        low_cpu_mem_usage=None,
        subfolder=None,
        tokenizer=None,
        batch_size=1,
    ):
        """Load a Hugging Face causal LM and its tokenizer for evaluation.

        Args:
            device: Device spec used when at most one GPU is visible. Accepts
                "cuda", "cpu", a bare GPU index given as a string (e.g. "0"),
                or any other string understood by ``torch.device`` (e.g.
                "cuda:1"). An empty string auto-selects CUDA when available,
                otherwise CPU. Ignored when more than one GPU is visible.
            pretrained: Model name or path passed to ``from_pretrained``.
            revision: Model revision passed to ``from_pretrained``.
            low_cpu_mem_usage: Forwarded unchanged to the model's
                ``from_pretrained`` call.
            subfolder: Optional subfolder; appended to ``revision`` as
                ``"<revision>/<subfolder>"``.
            tokenizer: Optional tokenizer name or path; defaults to
                ``pretrained``.
            batch_size: Per-GPU batch size.
        """
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, int)

        gpus = torch.cuda.device_count()
        if gpus <= 1:
            if device:
                if device not in ["cuda", "cpu"]:
                    # A bare index such as "0" becomes an int; anything else
                    # (e.g. "cuda:1") is handed to torch.device as-is instead
                    # of crashing in int().
                    try:
                        device = int(device)
                    except ValueError:
                        pass
                self._device = torch.device(device)
                print(f"Using device '{device}'")
            else:
                print("Device not specified")
                print(f"Cuda Available? {torch.cuda.is_available()}")
                self._device = (
                    torch.device("cuda")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
        else:
            # Multi-GPU: load on CPU first; the per-process device is assigned
            # below once the accelerator has been created.
            self._device = torch.device("cpu")

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        revision = revision + ("/" + subfolder if subfolder is not None else "")

        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage
        ).to(self.device)
        self.gpt2.eval()

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
        )

        self.vocab_size = self.tokenizer.vocab_size

        # multithreading and batching
        self.batch_size_per_gpu = batch_size  # todo: adaptive batch size

        if gpus > 1:
            accelerator = Accelerator(device_placement=False)
            self.gpt2 = accelerator.prepare(self.gpt2)
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.accelerator = accelerator

            if self.accelerator.is_local_main_process:
                print(f"Using {gpus} GPUs with FullyShardedDataParalell and accelerate")

            self._rank = self.accelerator.local_process_index
            self._world_size = gpus
        else:
            # The rank/world_size properties read these attributes; make sure
            # they exist in the single-process case without clobbering values
            # the base class may already have set.
            self._rank = getattr(self, "_rank", 0)
            self._world_size = getattr(self, "_world_size", 1)
80

81
82
83
84
    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id
85

86
87
88
89
90
91
92
    @property
    def max_length(self):
        try:
            return self.gpt2.config.n_ctx
        except AttributeError:
            # gptneoconfig doesn't have n_ctx apparently
            return self.gpt2.config.max_position_embeddings
93

94
95
96
    @property
    def max_gen_toks(self):
        return 256
Leo Gao's avatar
Leo Gao committed
97

98
99
100
101
    @property
    def batch_size(self):
        # TODO: fix multi-gpu
        return self.batch_size_per_gpu  # * gpus
Leo Gao's avatar
Leo Gao committed
102

103
104
105
106
    @property
    def device(self):
        # TODO: fix multi-gpu
        return self._device
107
108
109
110
111
112
113
114
    
    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size
Leo Gao's avatar
Leo Gao committed
115

116
117
    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)
Fabrizio Milo's avatar
Fabrizio Milo committed
118

119
120
121
    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

Leo Gao's avatar
Leo Gao committed
122
123
124
125
126
127
    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
128
        logits returned from the model
Leo Gao's avatar
Leo Gao committed
129
        """
130
        with torch.no_grad():
131
            return self.gpt2(inps)[0]
Fabrizio Milo's avatar
Fabrizio Milo committed
132

133
134
    def _model_generate(self, context, max_length, eos_token_id):
        return self.gpt2.generate(
135
            context, max_length=max_length, pad_token_id=eos_token_id, eos_token_id=eos_token_id, do_sample=False
136
137
        )

138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context)

            continuation_enc = self.tok_encode(continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        """Return one summed log-likelihood per request string.

        Each request's ``args`` is unpacked as a 1-tuple ``(string,)``. The
        string is tokenized, split into windows by the ``utils`` rolling-window
        helpers (prefixed with the end-of-text token, at most ``max_length``
        long), scored with ``_loglikelihood_tokens``, and the per-window
        log-likelihoods are summed.
        """
        # TODO: Implement caching once we've confirmed the perplexity implementation
        # TODO: automatic batch size detection for vectorization

        totals = []
        for (string,) in tqdm([req.args for req in requests]):
            token_windows = utils.get_rolling_token_windows(
                token_list=self.tok_encode(string),
                prefix_token=self.eot_token_id,
                max_seq_len=self.max_length,
                context_len=1,
            )
            disjoint_windows = list(map(utils.make_disjoint_window, token_windows))

            # prepend a None cache key so each entry has the 3-tuple shape
            # that _loglikelihood_tokens unpacks
            window_requests = [(None,) + window for window in disjoint_windows]

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that
            window_results = self._loglikelihood_tokens(
                window_requests, disable_tqdm=True
            )

            # keep the log-likelihood, discard is_greedy
            totals.append(sum([result[0] for result in window_results]))

        return totals

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        """Score pre-tokenized continuations in batches.

        Args:
            requests: iterable of 3-tuples ``(cache_key, context_enc,
                continuation_enc)``; the two encodings are concatenated with
                ``+`` and passed to ``torch.tensor``, so they are expected to
                be lists of token ids. ``cache_key`` is unpacked but not used
                in this method.
            disable_tqdm: forwarded to ``tqdm(..., disable=...)`` to hide the
                progress bar.

        Returns:
            The result of ``re_ord.get_original`` applied to a list of
            ``(log_prob_sum, is_greedy)`` tuples, one per request
            (presumably restored to the caller's order - confirm against
            ``utils.Reorderer``). ``log_prob_sum`` is the summed log-softmax
            value at the continuation tokens; ``is_greedy`` is True when the
            per-position argmax equals every continuation token.
        """
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            # sort key: (negative total length, the tokens themselves as a tie-breaker)
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        # TODO: automatic (variable) batch size detection for vectorization
        re_ord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
            tqdm(re_ord.get_reordered(), disable=disable_tqdm), self.batch_size
        ):
            # per-batch accumulators, index-aligned with `chunk`
            inps = []
            cont_toks_list = []
            inplens = []

            # set from the first item of the chunk, then reused for the rest
            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by the trailing [:-1]
                # gpt2    \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      logits[inplen - contlen : inplen] slice below

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                    dtype=torch.long,
                ).to(self.device)
                (inplen,) = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = (
                    padding_length if padding_length is not None else inplen
                )

                # right-pad with zeros (token id 0) from seq to padding_length; no attention mask is
                # built here, and the padded positions are never read back (see the slice below)
                inp = torch.cat(
                    [
                        inp,  # [seq]
                        torch.zeros(padding_length - inplen, dtype=torch.long).to(
                            inp.device
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                )

                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(cont)
                inplens.append(inplen)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            # log-probabilities over the vocab, moved to CPU for the per-item bookkeeping below
            multi_logits = F.log_softmax(
                self._model_call(batched_inps), dim=-1
            ).cpu()  # [batch, padding_length, vocab]

            # `cache_key` and `inp` are unpacked here but not used in the loop body
            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
                chunk, multi_logits, inps, inplens, cont_toks_list
            ):

                # Slice to the positions that predict the continuation: the last `contlen`
                # positions of the unpadded input
                contlen = len(cont_toks)
                logits = logits[inplen - contlen : inplen].unsqueeze(
                    0
                )  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
                    0
                )  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                    -1
                )  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

                res.append(answer)

        return re_ord.get_original(res)

    def greedy_until(self, requests):
        """Greedily generate text for each request and cut it at the stop strings.

        Each request's ``args`` is unpacked as ``(context, until)``, where
        ``until`` is a stop string or a list of stop strings. The first stop
        string must encode to exactly one token (the 1-tuple unpacking raises
        otherwise); that token is used as the EOS id during generation. The
        decoded continuation is then truncated at each stop string in turn.
        Results come back through ``re_ord.get_original``.
        """
        # TODO: implement fully general `until` that handles until that are
        #       multiple tokens or that span multiple tokens correctly

        def _collate(req_args):
            # sort key: (context length in tokens, context string)
            context_toks = self.tok_encode(req_args[0])
            return len(context_toks), req_args[0]

        re_ord = utils.Reorderer([req.args for req in requests], _collate)

        generations = []
        for context, until in tqdm(re_ord.get_reordered()):
            stop_sequences = [until] if isinstance(until, str) else until

            (primary_until,) = self.tok_encode(stop_sequences[0])

            # keep only the rightmost (max_length - max_gen_toks) context tokens
            # so that context plus generated tokens fits in max_length
            context_toks = self.tok_encode(context)
            keep_from = self.max_gen_toks - self.max_length
            context_enc = torch.tensor([context_toks[keep_from:]]).to(self.device)
            context_len = context_enc.shape[1]

            cont = self._model_generate(
                context_enc, context_len + self.max_gen_toks, primary_until
            )

            # drop the echoed context, keep only the newly generated tokens
            text = self.tok_decode(cont[0].tolist()[context_len:])

            for term in stop_sequences:
                text = text.split(term)[0]

            generations.append(text)

        return re_ord.get_original(generations)