hf_causal.py 15.3 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
import torch
Xingjian Shi's avatar
Xingjian Shi committed
2
import transformers
Jason Phang's avatar
gpt3  
Jason Phang committed
3

4
import copy
5
from tqdm import tqdm
Jason Phang's avatar
gpt3  
Jason Phang committed
6

7
8
9
import torch.nn.functional as F

from lm_eval import utils
lintangsutawika's avatar
lintangsutawika committed
10
from lm_eval.logger import eval_logger
11
12
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
13

14
from accelerate import Accelerator
haileyschoelkopf's avatar
haileyschoelkopf committed
15
from typing import Optional, Union
16
17


18
@register_model("hf-causal")
class HFLM(LM):
    """Causal language model backed by a HuggingFace ``AutoModelForCausalLM``.

    Runs on a single device, or data-parallel across GPUs via ``accelerate``
    when more than one GPU is visible.
    """

    def __init__(
        self,
        device="cuda",
        pretrained="gpt2",
        revision="main",
        low_cpu_mem_usage=None,
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        subfolder=None,
        tokenizer=None,
        batch_size=1,
    ):
        """Load the model and tokenizer and decide where they run.

        Args:
            device: Used when at most one GPU is visible. A bare index such as
                "0" selects that CUDA device; any other non-empty string
                ("cuda", "cpu", "cuda:1", ...) is passed to ``torch.device``.
                An empty string picks CUDA if available, else the CPU. With more
                than one GPU visible the model goes on
                ``cuda:<local process index>``; ``device`` is then only honoured
                when it is "cpu" and fewer processes than GPUs were launched.
            pretrained: Model name or path passed to ``from_pretrained``.
            revision: Model revision passed to ``from_pretrained``.
            low_cpu_mem_usage: Forwarded to ``from_pretrained``.
            dtype: Weight dtype, converted with ``utils.get_dtype``.
            subfolder: Optional subfolder, appended to ``revision`` (workaround).
            tokenizer: Tokenizer name or path; defaults to ``pretrained``.
            batch_size: Requests per forward pass on each device.
        """
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, int)

        gpus = torch.cuda.device_count()
        if gpus <= 1:
            if device:
                # A bare index like "0" means that CUDA device; anything else
                # ("cuda", "cpu", "cuda:0", "mps", ...) is understood by torch.device.
                self._device = torch.device(
                    int(device) if device.isdigit() else device
                )
                eval_logger.info(f"Using device '{device}'")
            else:
                eval_logger.info("Device not specified")
                eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                self._device = (
                    torch.device("cuda")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
            self._rank = 0
            self._world_size = 1
        else:
            # Load on the CPU first; the final placement is decided below once we
            # know how many processes accelerate launched.
            self._device = torch.device("cpu")

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        revision = revision + ("/" + subfolder if subfolder is not None else "")

        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained,
            revision=revision,
            low_cpu_mem_usage=low_cpu_mem_usage,
            torch_dtype=utils.get_dtype(dtype),
        ).to(self.device)
        self.model.eval()

        eval_logger.info(f"Model dtype: {self.model.dtype}")

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
        )

        self.vocab_size = self.tokenizer.vocab_size

        # multithreading and batching
        self.batch_size_per_gpu = batch_size  # todo: adaptive batch size

        # multigpu support with accelerate
        if gpus > 1:
            accelerator = Accelerator()
            if gpus > accelerator.num_processes:
                eval_logger.warning(
                    "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                    "If you would like to use data parallelism, please launch the script "
                    "with 'accelerate launch *script*'. "
                    f"Current run will proceed with {accelerator.num_processes} devices."
                )
                # Previously the model silently stayed on the CPU here. Use one
                # GPU per process instead, unless the CPU was explicitly requested.
                self._device = (
                    torch.device("cpu")
                    if device == "cpu"
                    else torch.device(f"cuda:{accelerator.local_process_index}")
                )
                self.model.to(self.device)
            else:
                self.model = accelerator.prepare(self.model)
                self._device = torch.device(f"cuda:{accelerator.local_process_index}")
                if accelerator.is_local_main_process:
                    eval_logger.info(f"Using {gpus} devices with data parallelism")

            # Kept in both branches: loglikelihood_rolling calls
            # self.accelerator.gather whenever world_size > 1.
            self.accelerator = accelerator
            # NOTE(review): local_process_index equals the global rank only on a
            # single node; confirm before relying on multi-node runs.
            self._rank = accelerator.local_process_index
            self._world_size = accelerator.num_processes

    @property
    def _unwrapped_model(self):
        """The model with any ``accelerate`` wrapper removed."""
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self.model)
        return self.model

    @property
    def eot_token_id(self):
        """Token id used as end-of-text (the tokenizer's EOS token)."""
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Maximum sequence length in tokens, read from the model config.

        Tries ``n_ctx`` first, then ``max_position_embeddings`` (e.g. GPT-Neo
        configs have no ``n_ctx``), then ``n_positions``.

        Raises:
            AttributeError: if the config defines none of these attributes.
        """
        config = self._unwrapped_model.config
        attrs = ("n_ctx", "max_position_embeddings", "n_positions")
        for attr in attrs:
            if hasattr(config, attr):
                return getattr(config, attr)
        raise AttributeError(
            "Could not determine the model's maximum sequence length from its "
            f"config (tried: {', '.join(attrs)})"
        )

    @property
    def max_gen_toks(self):
        """Default number of tokens to generate in ``greedy_until``."""
        return 256

    @property
    def batch_size(self):
        """Number of requests per forward pass on each device."""
        return self.batch_size_per_gpu

    @property
    def device(self):
        """Device the model inputs are moved to."""
        return self._device

    @property
    def rank(self):
        """Index of this process (0 when not running data-parallel)."""
        return self._rank

    @property
    def world_size(self):
        """Number of processes taking part in evaluation."""
        return self._world_size

    def tok_encode(self, string: str):
        """Encode ``string`` to token ids without adding special tokens."""
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        """Decode token ids back to a string."""
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            return self.model(inps)[0]

    def _model_generate(self, context, max_length, eos_token_id, **generation_kwargs):
        """Run ``generate`` on the (unwrapped) model.

        ``eos_token_id`` is used both as the stop token and as the pad token.
        Extra keyword arguments are forwarded to ``generate``.
        """
        # We require users to pass do_sample=True explicitly for non-greedy
        # generation. This should be reevaluated when considering beam search.
        generation_kwargs.setdefault("do_sample", False)
        return self._unwrapped_model.generate(
            context,
            max_length=max_length,
            pad_token_id=eos_token_id,
            eos_token_id=eos_token_id,
            **generation_kwargs,
        )

    def loglikelihood(self, requests):
        """Score continuations given contexts.

        Args:
            requests: Objects whose ``.args`` is a ``(context, continuation)``
                pair of strings. An empty context is replaced by the EOT token.

        Returns:
            One ``(logprob, is_greedy)`` tuple per request, in request order.
        """
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context)

            continuation_enc = self.tok_encode(continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        """Total log-likelihood of whole strings, scored in rolling windows.

        Args:
            requests: Objects whose ``.args`` is a 1-tuple ``(string,)``.

        Returns:
            One summed log-likelihood per request. Windows come from
            ``utils.get_rolling_token_windows`` with the EOT token as prefix and
            windows of at most ``max_length`` tokens.
        """
        # TODO: Implement caching once we've confirmed the perplexity implementation
        # TODO: automatic batch size detection for vectorization

        loglikelihoods = []
        for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
            windows = utils.get_rolling_token_windows(
                token_list=self.tok_encode(string),
                prefix_token=self.eot_token_id,
                max_seq_len=self.max_length,
                context_len=1,
            )
            # Cache key is None: rolling windows are not cached.
            rolling_token_windows = [
                (None,) + window for window in map(utils.make_disjoint_window, windows)
            ]

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that

            pad_amnt = 0
            if self.world_size > 1:
                # Bring every process up to the same number of windows (the max
                # across ranks), presumably so all ranks issue the same number of
                # model calls. The padded results are discarded below.
                num_windows = torch.tensor(len(rolling_token_windows), device=self.device)
                gathered = (
                    self.accelerator.gather(num_windows).cpu().detach().numpy().tolist()
                )
                pad_amnt = max(gathered) - gathered[self.rank]
                if pad_amnt > 0:
                    # A rank may have no windows at all (e.g. an empty string);
                    # pad with a minimal dummy window in that case.
                    filler = (
                        rolling_token_windows[0]
                        if rolling_token_windows
                        else (None, [self.eot_token_id], [self.eot_token_id])
                    )
                    rolling_token_windows += pad_amnt * [filler]

            string_nll = self._loglikelihood_tokens(
                rolling_token_windows, disable_tqdm=True
            )
            if pad_amnt > 0:
                string_nll = string_nll[:-pad_amnt]

            # discard is_greedy
            loglikelihoods.append(sum(logprob for logprob, _ in string_nll))

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        """Score pre-tokenized (context, continuation) pairs in batches.

        Args:
            requests: Sequence of ``(cache_key, context_enc, continuation_enc)``
                tuples; the encodings are lists of token ids.
            disable_tqdm: Hide the progress bar (always hidden on non-zero ranks).

        Returns:
            List, in the original request order, of ``(logprob, is_greedy)``:
            the summed log-probability of the continuation tokens and whether
            every continuation token was the per-position argmax.
        """
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        max_length = self.max_length

        # TODO: automatic (variable) batch size detection for vectorization
        re_ord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
            tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
            self.batch_size,
        ):
            inps = []
            cont_toks_list = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by [:-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      logits[inplen - contlen : inplen] slice below

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(max_length + 1) :][:-1],
                    dtype=torch.long,
                ).to(self.device)
                (inplen,) = inp.shape

                # since in _collate we make sure length is descending, the longest is always the first one.
                if padding_length is None:
                    padding_length = inplen

                # right-pad with zeros up to padding_length
                inp = torch.cat(
                    [
                        inp,  # [seq]
                        torch.zeros(
                            padding_length - inplen, dtype=torch.long, device=inp.device
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                )

                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(
                self._model_call(batched_inps), dim=-1
            ).cpu()  # [batch, padding_length, vocab]

            for logits, inplen, cont_toks in zip(multi_logits, inplens, cont_toks_list):
                # Slice to original seq length
                contlen = len(cont_toks)
                logits = logits[inplen - contlen : inplen].unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]

                # Answer: (log prob, is-exact-match)
                res.append((float(logits.sum()), bool(max_equal)))

        return re_ord.get_original(res)

    def greedy_until(self, requests):
        """Generate a continuation for each context until a stop sequence.

        Args:
            requests: Objects whose ``.args`` is ``(context, gen_kwargs)``.
                ``gen_kwargs`` must be a dict. Its optional ``"until"`` entry (a
                string or list of strings) gives the stop sequences and defaults
                to the decoded EOT token. Its optional ``"max_gen_toks"`` entry
                defaults to ``self.max_gen_toks``. Remaining entries are
                forwarded to ``generate``.

        Returns:
            The generated strings, truncated at the first stop sequence, in
            request order.

        Raises:
            ValueError: if ``gen_kwargs`` is not a dict, or ``"until"`` is
                neither a string nor a list.
        """
        # TODO: implement fully general `until` that handles until that are
        #       multiple tokens or that span multiple tokens correctly

        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer([req.args for req in requests], _collate)

        for context, gen_kwargs in tqdm(
            re_ord.get_reordered(), disable=(self.rank != 0)
        ):
            if not isinstance(gen_kwargs, dict):
                raise ValueError(
                    f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
                )
            gen_kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1

            until = None
            if "until" in gen_kwargs:
                until = gen_kwargs.pop("until")
                if isinstance(until, str):
                    # Wrap the stop string itself (previously wrapped gen_kwargs by mistake).
                    until = [until]
                elif not isinstance(until, list):
                    raise ValueError(
                        f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
                    )
            if not until:
                until = [self.tok_decode(self.eot_token_id)]
            max_gen_toks = gen_kwargs.pop("max_gen_toks", self.max_gen_toks)

            try:
                (primary_until,) = self.tok_encode(until[0])
            except ValueError:
                # if our primary until would be multiple tokens long, we'll have errors.
                # TODO: handling this better will let us stop generating earlier + often.
                primary_until = self.eot_token_id

            # Keep only the last (max_length - max_gen_toks) context tokens so the
            # prompt plus the generated tokens fit in the model's context.
            context_enc = torch.tensor(
                [self.tok_encode(context)[max_gen_toks - self.max_length :]]
            ).to(self.device)

            cont = self._model_generate(
                context=context_enc,
                max_length=context_enc.shape[1] + max_gen_toks,
                eos_token_id=primary_until,
                **gen_kwargs,
            )

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])

            for term in until:
                s = s.split(term)[0]

            res.append(s)

        return re_ord.get_original(res)