hf_causal.py 15.1 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
import torch
Xingjian Shi's avatar
Xingjian Shi committed
2
import transformers
Jason Phang's avatar
gpt3  
Jason Phang committed
3

4
import copy
5
from tqdm import tqdm
Jason Phang's avatar
gpt3  
Jason Phang committed
6

7
8
9
import torch.nn.functional as F

from lm_eval import utils
lintangsutawika's avatar
lintangsutawika committed
10
from lm_eval.logger import eval_logger
11
12
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
13

14
15
from accelerate import Accelerator
from itertools import islice
16
17


18
@register_model("hf-causal")
class HFLM(LM):
    """Evaluate a Hugging Face causal language model (``AutoModelForCausalLM``).

    Registered under the name ``"hf-causal"``. Runs on a single device by
    default. When several processes are launched via ``accelerate`` on a
    multi-GPU machine, the model is wrapped with ``Accelerator.prepare`` and
    ``rank`` / ``world_size`` reflect the launched processes.
    """

    def __init__(
        self,
        device="cuda",
        pretrained="gpt2",
        revision="main",
        low_cpu_mem_usage=None,
        subfolder=None,
        tokenizer=None,
        batch_size=1,
    ):
        """Load the model and tokenizer and choose the device(s) to run on.

        Args:
            device: Device used when running as a single process, e.g.
                ``"cuda"``, ``"cpu"``, ``"cuda:1"`` or a bare CUDA ordinal such
                as ``"0"``. An empty string auto-selects CUDA when available.
                In a multi-process (data-parallel) run each process instead
                uses ``cuda:<local_process_index>``.
            pretrained: Model name or path passed to ``from_pretrained``.
            revision: Model revision (branch, tag or commit).
            low_cpu_mem_usage: Forwarded to
                ``AutoModelForCausalLM.from_pretrained``.
            subfolder: Optional subfolder; appended to ``revision`` (see TODO).
            tokenizer: Tokenizer name or path; defaults to ``pretrained``.
            batch_size: Number of requests per forward pass on each device.
        """
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, int)

        accelerator = self._make_accelerator()
        if accelerator is not None and accelerator.num_processes > 1:
            # Data-parallel run: each launched process drives its own GPU.
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self._rank = accelerator.local_process_index
            self._world_size = accelerator.num_processes
        else:
            # Single process (possibly on a multi-GPU machine): honour `device`
            # instead of silently falling back to the CPU.
            self._device = self._resolve_device(device)
            self._rank = 0
            self._world_size = 1

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        revision = revision + ("/" + subfolder if subfolder is not None else "")

        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage
        ).to(self.device)
        self.model.eval()

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
        )

        self.vocab_size = self.tokenizer.vocab_size

        # multithreading and batching
        self.batch_size_per_gpu = batch_size  # todo: adaptive batch size

        # multigpu support with accelerate. `self.accelerator` is set exactly
        # when world_size > 1, so every code path that needs collective ops
        # (e.g. `gather` in loglikelihood_rolling) can rely on it.
        if self._world_size > 1:
            self.model = accelerator.prepare(self.model)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(
                    f"Using {self._world_size} devices with data parallelism"
                )

    @staticmethod
    def _make_accelerator():
        """Create an ``Accelerator`` when more than one GPU is visible, else ``None``.

        Warns when fewer processes were launched than there are GPUs.
        """
        gpus = torch.cuda.device_count()
        if gpus <= 1:
            return None
        accelerator = Accelerator()
        if gpus > accelerator.num_processes:
            eval_logger.warning(
                "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                "If you would like to use data parallelism, please launch the script "
                "with 'accelerate launch *script*'. "
                f"Current run will proceed with {accelerator.num_processes} devices."
            )
        return accelerator

    @staticmethod
    def _resolve_device(device):
        """Turn a user device string into a ``torch.device``; auto-detect if empty."""
        if device:
            # A bare number is a CUDA ordinal ("0" -> cuda:0). Anything else
            # ("cuda", "cpu", "cuda:1", ...) is parsed by torch.device directly.
            if device.isdigit():
                device = int(device)
            resolved = torch.device(device)
            eval_logger.info(f"Using device '{device}'")
            return resolved

        eval_logger.info("Device not specified")
        eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
        return (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )

    def _unwrapped_model(self):
        """Return the underlying HF model, unwrapping the accelerate wrapper if any."""
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self.model)
        return self.model

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Maximum sequence length accepted by the model, read from its config.

        Architectures name this field differently (e.g. GPT-Neo configs have no
        ``n_ctx``), so several common names are tried in order.

        Raises:
            AttributeError: if none of the known fields is set on the config.
        """
        config = self._unwrapped_model().config
        for attr in ("n_ctx", "max_position_embeddings", "n_positions"):
            value = getattr(config, attr, None)
            if value is not None:
                return value
        raise AttributeError(
            f"Could not determine the maximum sequence length from "
            f"{type(config).__name__}; expected one of n_ctx, "
            "max_position_embeddings or n_positions"
        )

    @property
    def max_gen_toks(self):
        """Default number of tokens to generate when a request does not say."""
        return 256

    @property
    def batch_size(self):
        """Number of requests per forward pass on this process's device."""
        return self.batch_size_per_gpu

    @property
    def device(self):
        """``torch.device`` this process runs the model on."""
        return self._device

    @property
    def rank(self):
        """Local process index (0 when not running data-parallel)."""
        return self._rank

    @property
    def world_size(self):
        """Number of data-parallel processes (1 when not running data-parallel)."""
        return self._world_size

    def tok_encode(self, string: str):
        """Encode ``string`` to token ids without the tokenizer's special tokens."""
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        """Decode token id(s) back to text."""
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            return self.model(inps)[0]

    def _model_generate(self, context, max_length, eos_token_id, **generation_kwargs):
        """Run ``generate`` on the (unwrapped) model.

        ``eos_token_id`` is also used as the pad token id.
        """
        # we require users to pass do_sample=True explicitly
        # for non-greedy gen. This should be reevaluated when considering beam search.
        generation_kwargs.setdefault("do_sample", False)
        return self._unwrapped_model().generate(
            context,
            max_length=max_length,
            pad_token_id=eos_token_id,
            eos_token_id=eos_token_id,
            **generation_kwargs,
        )

    def loglikelihood(self, requests):
        """Score each ``(context, continuation)`` request.

        Returns a list of ``(log_prob, is_greedy)`` tuples, one per request.
        """
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # end of text as context, so the continuation still has
                # something to be conditioned on (context_enc must be non-empty)
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context)

            continuation_enc = self.tok_encode(continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        """Return the total log-likelihood of each request's full string.

        The string is split into disjoint windows of at most ``max_length``
        tokens. Each window is scored, and the window scores are summed.
        """
        # TODO: Implement caching once we've confirmed the perplexity implementation
        # TODO: automatic batch size detection for vectorization

        loglikelihoods = []
        for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            # prepend a `None` cache key to match the
            # (cache_key, context_enc, continuation_enc) layout of _loglikelihood_tokens
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that

            pad_amnt = 0
            if self.world_size > 1:
                # Pad this process's window list up to the longest across all
                # processes, presumably so every process runs the same number of
                # forward passes through the wrapped model. The filler results
                # are dropped below.
                mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
                gathered = self.accelerator.gather(mytensor).cpu().tolist()

                pad_amnt = max(gathered) - gathered[self.rank]
                if pad_amnt > 0:
                    # A string can yield no windows at all; use a minimal
                    # one-token window as filler then.
                    filler = (
                        rolling_token_windows[0]
                        if rolling_token_windows
                        else (None, [self.eot_token_id], [self.eot_token_id])
                    )
                    rolling_token_windows += pad_amnt * [filler]

            string_nll = self._loglikelihood_tokens(
                rolling_token_windows, disable_tqdm=True
            )

            if pad_amnt > 0:
                # drop the results belonging to the filler windows
                string_nll = string_nll[:-pad_amnt]

            # discard is_greedy and sum the per-window log-likelihoods
            loglikelihoods.append(sum(nll for nll, _ in string_nll))

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        """Score pre-tokenized ``(cache_key, context_enc, continuation_enc)`` requests.

        Returns ``(sum of continuation log-probs, continuation is greedy argmax)``
        for each request, in the order the requests were given.
        """
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []
        # hoisted: the property inspects (and may unwrap) the model on every access
        max_length = self.max_length

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        # TODO: automatic (variable) batch size detection for vectorization
        re_ord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
            tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
            self.batch_size,
        ):
            inps = []
            cont_toks_list = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by the [:-1] slice below
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [inplen - contlen : inplen] slice below

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(max_length + 1) :][:-1],
                    dtype=torch.long,
                    device=self.device,
                )
                (inplen,) = inp.shape

                # since in _collate we make sure length is descending, the longest is always the first one.
                if padding_length is None:
                    padding_length = inplen

                # right-pad with zeros from [seq] to [padding_length]
                inp = F.pad(inp, (0, padding_length - inplen))

                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(
                self._model_call(batched_inps), dim=-1
            ).cpu()  # [batch, padding_length, vocab]

            for logits, inplen, cont_toks in zip(multi_logits, inplens, cont_toks_list):
                # Slice to original seq length, keeping only the continuation positions
                contlen = len(cont_toks)
                logits = logits[inplen - contlen : inplen].unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]

                # Answer: (log prob, is-exact-match)
                res.append((float(logits.sum()), bool(max_equal)))

        return re_ord.get_original(res)

    def greedy_until(self, requests):
        """Generate a continuation for each ``(context, gen_kwargs)`` request.

        ``gen_kwargs`` must be a dict. Its optional ``"until"`` (a string or
        list of strings) gives stop sequences; generated text is cut at the
        first occurrence of each. Its optional ``"max_gen_toks"`` overrides the
        default generation budget. Any other keys are forwarded to ``generate``.
        """
        # TODO: implement fully general `until` that handles until that are
        #       multiple tokens or that span multiple tokens correctly

        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer([req.args for req in requests], _collate)

        for context, gen_kwargs in tqdm(
            re_ord.get_reordered(), disable=(self.rank != 0)
        ):
            if not isinstance(gen_kwargs, dict):
                raise ValueError(
                    f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
                )
            gen_kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1

            # reset per request so a previous request's stop sequences never leak in
            until = None
            if "until" in gen_kwargs:
                until = gen_kwargs.pop("until")
                if isinstance(until, str):
                    until = [until]
                elif not isinstance(until, list):
                    raise ValueError(
                        f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
                    )
            if not until:
                until = [self.tok_decode(self.eot_token_id)]

            max_gen_toks = gen_kwargs.pop("max_gen_toks", self.max_gen_toks)

            try:
                (primary_until,) = self.tok_encode(until[0])
            except ValueError:
                # if our primary until would be multiple tokens long, we'll have errors.
                # TODO: handling this better will let us stop generating earlier + often.
                primary_until = self.eot_token_id

            # keep only the most recent tokens so context + generation fits in max_length
            context_enc = torch.tensor(
                [self.tok_encode(context)[max_gen_toks - self.max_length :]]
            ).to(self.device)

            cont = self._model_generate(
                context=context_enc,
                max_length=context_enc.shape[1] + max_gen_toks,
                eos_token_id=primary_until,
                **gen_kwargs,
            )

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])

            for term in until:
                s = s.split(term)[0]

            res.append(s)

        return re_ord.get_original(res)