hf_causal.py 15.9 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
import torch
Xingjian Shi's avatar
Xingjian Shi committed
2
import transformers
Jason Phang's avatar
gpt3  
Jason Phang committed
3

4
import copy
5
from tqdm import tqdm
Jason Phang's avatar
gpt3  
Jason Phang committed
6

7
8
9
import torch.nn.functional as F

from lm_eval import utils
lintangsutawika's avatar
lintangsutawika committed
10
from lm_eval.logger import eval_logger
11
12
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
13

lintangsutawika's avatar
lintangsutawika committed
14
15
from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria

16
from accelerate import Accelerator
haileyschoelkopf's avatar
haileyschoelkopf committed
17
from typing import Optional, Union
18
19


20
@register_model("hf-causal")
class HFCausalLM(LM):
    """``LM`` implementation backed by a Hugging Face causal language model.

    Loads ``pretrained`` through ``transformers.AutoModelForCausalLM`` and
    ``transformers.AutoTokenizer`` and answers ``loglikelihood``,
    ``loglikelihood_rolling`` and ``greedy_until`` requests. When more than one
    CUDA device is visible, an ``accelerate.Accelerator`` determines the
    per-process device, rank and world size.
    """

    def __init__(
        self,
        device="cuda",
        pretrained="gpt2",
        revision="main",
        low_cpu_mem_usage=None,
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        subfolder=None,
        tokenizer=None,
        batch_size=1,
    ):
        """Load the model and tokenizer and set up device placement.

        Args:
            device: Used only when at most one CUDA device is visible. Any
                string accepted by ``torch.device`` (``"cuda"``, ``"cpu"``,
                ``"cuda:1"``, ...) or a bare CUDA ordinal such as ``"0"``.
                An empty string selects CUDA when available, else CPU.
            pretrained: Model name or path passed to ``from_pretrained``.
            revision: Model revision passed to ``from_pretrained``.
            low_cpu_mem_usage: Forwarded to ``from_pretrained``.
            dtype: Model dtype (a ``torch.dtype``, its name, or ``"auto"``),
                converted with ``utils.get_dtype``.
            subfolder: Optional subfolder; appended to ``revision``.
            tokenizer: Tokenizer name or path; defaults to ``pretrained``.
            batch_size: Per-device batch size for loglikelihood scoring.
        """
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, int)

        gpus = torch.cuda.device_count()

        if gpus <= 1:
            if device:
                # A bare ordinal ("0", "1", ...) selects that CUDA device; every
                # other string is handed to torch.device unchanged. Converting
                # anything but "cuda"/"cpu" with int() made e.g. "cuda:0" or
                # "mps" raise ValueError.
                if device.isdigit():
                    device = int(device)
                self._device = torch.device(device)
                eval_logger.info(f"Using device '{device}'")
            else:
                eval_logger.info("Device not specified")
                eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                self._device = (
                    torch.device("cuda")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
            self._rank = 0
            self._world_size = 1
        else:
            # Load onto the CPU first; the multi-GPU setup below picks the
            # per-process device. The `device` argument is not used here.
            self._device = torch.device("cpu")

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        revision = revision + ("/" + subfolder if subfolder is not None else "")

        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained,
            revision=revision,
            low_cpu_mem_usage=low_cpu_mem_usage,
            torch_dtype=utils.get_dtype(dtype),
        ).to(self.device)
        self.model.eval()

        eval_logger.info(self.model.dtype)

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
        )

        self.vocab_size = self.tokenizer.vocab_size

        # multithreading and batching
        self.batch_size_per_gpu = batch_size  # todo: adaptive batch size

        # multigpu support with accelerate
        if gpus > 1:
            accelerator = Accelerator()
            # Keep the accelerator in both branches: loglikelihood_rolling calls
            # self.accelerator.gather whenever world_size > 1, which can also
            # happen in the first branch (more than one process, but fewer
            # processes than visible GPUs). unwrap_model returns a model that
            # was never wrapped unchanged.
            self.accelerator = accelerator
            if gpus > accelerator.num_processes:
                eval_logger.warning(
                    "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                    "If you would like to use data parallelism, please launch the script "
                    "with 'accelerate launch *script*'. "
                    f"Current run will proceed with {accelerator.num_processes} devices."
                )
                # manually set model to use gpu, for case where many GPUs available but
                # only seek to use one
                self._device = (
                    torch.device(f"cuda:{accelerator.local_process_index}")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
                self.model.to(self.device)
            else:
                self.model = accelerator.prepare(self.model)
                self._device = torch.device(f"cuda:{accelerator.local_process_index}")

                if accelerator.is_local_main_process:
                    eval_logger.info(f"Using {gpus} devices with data parallelism")

            # The CUDA device is chosen by the per-node (local) index above, but
            # rank must be the global process index: loglikelihood_rolling
            # indexes the result of accelerator.gather (one entry per process,
            # in global process order) with self.rank. Local and global indices
            # differ on multi-node runs.
            self._rank = accelerator.process_index
            self._world_size = accelerator.num_processes

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def _unwrapped_model(self):
        """The underlying model, with any ``accelerate`` wrapper removed."""
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self.model)
        return self.model

    @property
    def max_length(self):
        """Context size from the model config: ``n_ctx``, else ``max_position_embeddings``."""
        config = self._unwrapped_model.config
        try:
            return config.n_ctx
        except AttributeError:
            # gptneoconfig doesn't have n_ctx apparently
            return config.max_position_embeddings

    @property
    def max_gen_toks(self):
        """Default cap on newly generated tokens in ``greedy_until``."""
        return 256

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def tok_encode(self, string: str):
        """Encode ``string`` to token ids without adding special tokens."""
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        """Decode token ids back to text."""
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            return self.model(inps).logits

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        """Run ``generate`` on ``context`` with stop-sequence criteria.

        Args:
            context: tensor of prompt token ids, shape ``[batch, seq]``.
            max_length: total length (prompt + generated) passed to ``generate``.
            stop: a stop string or a list of stop strings, forwarded to
                ``stop_sequences_criteria``.
            **generation_kwargs: extra ``generate`` keyword arguments;
                ``do_sample`` defaults to ``False`` (greedy).

        Returns:
            The output of ``generate`` (prompt ids followed by generated ids).
        """
        # we require users to pass do_sample=True explicitly
        # for non-greedy gen. This should be reevaluated when considering beam search.
        generation_kwargs.setdefault("do_sample", False)

        # stop_sequences_criteria builds one criterion per element of `stop`;
        # a bare string would be iterated character by character.
        if isinstance(stop, str):
            stop = [stop]

        # Pass the prompt length so the criteria skip the prompt tokens and only
        # inspect newly generated ones (assumes stop_sequences_criteria treats
        # its third argument as the number of leading input positions to skip;
        # confirm in lm_eval.utils). With 1, the tail of the prompt could
        # trigger an early stop.
        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, context.shape[1], context.shape[0]
        )
        return self._unwrapped_model.generate(
            context,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.eot_token_id,
            use_cache=True,
            **generation_kwargs,
        )

    def loglikelihood(self, requests):
        """Score ``(context, continuation)`` request pairs.

        An empty context is replaced by the EOT token so the continuation is
        always conditioned on at least one token.

        Returns:
            list of ``(logprob, is_greedy)`` tuples, one per request.
        """
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context)

            continuation_enc = self.tok_encode(continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        """Return the summed log-likelihood of each request's whole string.

        Each string is tokenized and split into windows with
        ``utils.get_rolling_token_windows`` (EOT prefix, ``max_seq_len`` =
        ``max_length``, ``context_len=1``) and ``utils.make_disjoint_window``;
        the per-window log-likelihoods are summed.
        """
        loglikelihoods = []
        for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            # Prepend a None cache key to match the 3-tuples that
            # _loglikelihood_tokens unpacks (assumes make_disjoint_window
            # returns a (context, continuation) pair).
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that

            pad_amnt = 0
            if self.world_size > 1:
                # Gather every rank's window count and pad this rank's list with
                # copies of its first window up to the maximum, so that all ranks
                # score the same number of windows (presumably to keep ranks in
                # lockstep). The padded results are discarded below.
                mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
                gathered = (
                    self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
                )

                pad_amnt = max(gathered) - gathered[self.rank]
                if pad_amnt > 0:
                    rolling_token_windows += pad_amnt * [rolling_token_windows[0]]

            string_nll = self._loglikelihood_tokens(
                rolling_token_windows, disable_tqdm=True
            )

            if pad_amnt > 0:
                # drop the results that belong to the padding windows
                string_nll = string_nll[:-pad_amnt]

            # discard is_greedy and sum the window log-likelihoods
            loglikelihoods.append(sum(x[0] for x in string_nll))

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        """Score pre-tokenized ``(cache_key, context_enc, continuation_enc)`` requests.

        Args:
            requests: sequence of 3-tuples whose last two items are lists of
                token ids; the first item is not used here.
            disable_tqdm: hide the progress bar (it is always hidden on ranks
                other than 0).

        Returns:
            list of ``(logprob, is_greedy)`` tuples in the original request
            order: the summed log-probability of the continuation tokens and
            whether every continuation token is the model's argmax prediction.
        """
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        max_length = self.max_length

        # TODO: automatic (variable) batch size detection for vectorization
        re_ord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
            tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
            self.batch_size,
        ):
            inps = []
            cont_toks_list = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      logits[inplen - contlen : inplen] slice below

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(max_length + 1) :][:-1],
                    dtype=torch.long,
                ).to(self.device)
                (inplen,) = inp.shape

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = (
                    padding_length if padding_length is not None else inplen
                )

                # right-pad with zeros up to padding_length
                inp = torch.cat(
                    [
                        inp,  # [seq]
                        torch.zeros(
                            padding_length - inplen,
                            dtype=torch.long,
                            device=inp.device,
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                )

                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(
                self._model_call(batched_inps), dim=-1
            ).cpu()  # [batch, padding_length, vocab]

            for logits, inplen, cont_toks in zip(multi_logits, inplens, cont_toks_list):
                # Slice to the positions that predict the continuation tokens
                contlen = len(cont_toks)
                logits = logits[inplen - contlen : inplen].unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]

                # Answer: (log prob, is-exact-match)
                res.append((float(logits.sum()), bool(max_equal)))

        return re_ord.get_original(res)

    def greedy_until(self, requests):
        """Generate a continuation for each request, cut at its stop sequences.

        Each request's ``args`` is ``(context, gen_kwargs)`` with ``gen_kwargs``
        a dict. Its optional ``"until"`` entry (a string or list of strings)
        gives the stop sequences and defaults to the decoded EOT token; its
        optional ``"max_gen_toks"`` entry caps the number of new tokens
        (default ``self.max_gen_toks``). Remaining keys are forwarded to
        ``generate``. Each returned string is truncated before the first
        occurrence of any stop sequence.

        Raises:
            ValueError: if ``gen_kwargs`` is not a dict, or ``"until"`` is
                neither a string nor a list.
        """
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer([req.args for req in requests], _collate)

        for context, gen_kwargs in tqdm(
            re_ord.get_reordered(), disable=(self.rank != 0)
        ):
            if not isinstance(gen_kwargs, dict):
                raise ValueError(
                    f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
                )
            # copy so popping keys below does not mutate a dict shared between
            # requests (edge case for repeats > 1)
            gen_kwargs = copy.deepcopy(gen_kwargs)

            until = None
            if "until" in gen_kwargs:
                until = gen_kwargs.pop("until")
                if isinstance(until, str):
                    until = [until]
                elif not isinstance(until, list):
                    raise ValueError(
                        f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
                    )
            if not until:
                until = [self.tok_decode(self.eot_token_id)]
            max_gen_toks = gen_kwargs.pop("max_gen_toks", self.max_gen_toks)

            # left-truncate the prompt to leave room for max_gen_toks new tokens
            context_enc = torch.tensor(
                [self.tok_encode(context)[max_gen_toks - self.max_length :]]
            ).to(self.device)

            cont = self._model_generate(
                context=context_enc,
                max_length=context_enc.shape[1] + max_gen_toks,
                # stop on any of the stop sequences, not only the first one
                stop=until,
                **gen_kwargs,
            )

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])

            for term in until:
                s = s.split(term)[0]

            res.append(s)

        return re_ord.get_original(res)