huggingface.py 28.4 KB
Newer Older
1
2
3
4
5
import torch
import transformers
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

import copy
6
from collections import defaultdict
7
8
9
10
11
12
13
14
15
16
17
18
from tqdm import tqdm

import torch.nn.functional as F

from lm_eval import utils
from lm_eval.logger import eval_logger
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model

from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria

from accelerate import Accelerator
19
from typing import List, Optional, Union
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44


def _get_accelerate_args(
    device_map_option: Optional[str] = "auto",
    max_memory_per_gpu: Optional[Union[int, str]] = None,
    max_cpu_memory: Optional[Union[int, str]] = None,
    offload_folder: Optional[str] = "./offload",
) -> dict:
    """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
    max_memory = {}
    if max_memory_per_gpu is not None:
        max_memory_per_gpu_map = {
            device_idx: max_memory_per_gpu
            for device_idx in range(torch.cuda.device_count())
        }
        max_memory.update(max_memory_per_gpu_map)
    if max_cpu_memory is not None:
        max_memory["cpu"] = max_cpu_memory

    args = {}
    if max_memory:
        args["max_memory"] = max_memory
    args["device_map"] = device_map_option
    args["offload_folder"] = offload_folder
    return args
45
46


47
@register_model("hf-auto", "hf", "huggingface")
48
class HFLM(LM):
49
50
51
52
53
54
55
    """
    An abstracted Huggingface model class. Enables usage with both models of
    `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.

    Supports data-parallel multi-GPU with HF Accelerate.
    """

56
    AUTO_MODEL_CLASS = None
57
    _DEFAULT_MAX_LENGTH = 2048
haileyschoelkopf's avatar
haileyschoelkopf committed
58

59
60
61
62
63
64
    def __init__(
        self,
        device="cuda",
        pretrained="gpt2",
        revision="main",
        low_cpu_mem_usage=None,
        max_length=None,
        subfolder=None,
        tokenizer=None,
        batch_size=1,
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        parallelize: Optional[bool] = False,
        device_map_option: Optional[str] = "auto",
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[str] = "./offload",
    ):
        """Load the config, model and tokenizer and decide device placement.

        :param device: str
            Device to use when at most one GPU is visible and `parallelize` is
            false. May be "cuda", "cpu", a bare index such as "0", or any other
            string accepted by `torch.device` (e.g. "cuda:0"). An empty string
            selects CUDA when available and CPU otherwise.
        :param pretrained: str
            Model identifier passed to the `from_pretrained` loaders.
        :param revision: str
            Revision passed to the loaders; `subfolder`, when given, is
            appended to it as "<revision>/<subfolder>".
        :param low_cpu_mem_usage: forwarded to the model's `from_pretrained`.
        :param max_length: optional override returned by the `max_length` property.
        :param subfolder: optional subfolder, see `revision`.
        :param tokenizer: optional tokenizer identifier used instead of `pretrained`.
        :param batch_size: int
            Stored as the per-GPU batch size.
        :param dtype: converted with `utils.get_dtype` and passed as `torch_dtype`.
        :param parallelize: when true, the kwargs built by `_get_accelerate_args`
            from `device_map_option`, `max_memory_per_gpu`, `max_cpu_memory` and
            `offload_folder` are passed to the model's `from_pretrained`.
        :raises RuntimeError: if `parallelize` is true while more than one GPU is
            visible and Accelerate reports more than one process.
        """
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, int)

        gpus = torch.cuda.device_count()
        single_device = gpus <= 1 and not parallelize

        if single_device:
            # use user-passed device
            if device:
                if device not in ("cuda", "cpu"):
                    try:
                        # a bare index such as "0" selects that device ordinal
                        device = int(device)
                    except ValueError:
                        # keep full device strings such as "cuda:0" or "mps";
                        # torch.device() validates them below
                        pass
                self._device = torch.device(device)
                eval_logger.info(f"Using device '{device}'")
            else:
                eval_logger.info("Device not specified")
                eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                self._device = (
                    torch.device("cuda")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
        else:
            eval_logger.info(
                f"Passed device '{device}', but using `accelerate launch` or `parallelize=True`. This will be overridden when placing model."
            )
            # TODO: include in warning that `load_in_8bit` etc. affect this too
            self._device = device

        model_kwargs = {}
        if parallelize:
            model_kwargs = _get_accelerate_args(
                device_map_option,
                max_memory_per_gpu,
                max_cpu_memory,
                offload_folder,
            )

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        revision = revision + ("/" + subfolder if subfolder is not None else "")

        # get config
        self._config = transformers.AutoConfig.from_pretrained(
            pretrained,
            revision=revision,
        )

        # pick the auto-class from the config's model type: anything that is not
        # a known causal-LM type is loaded as a seq2seq model
        if getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
            self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
        else:
            self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM

        assert self.AUTO_MODEL_CLASS in [
            transformers.AutoModelForCausalLM,
            transformers.AutoModelForSeq2SeqLM,
        ]

        self._model = self.AUTO_MODEL_CLASS.from_pretrained(
            pretrained,
            revision=revision,
            low_cpu_mem_usage=low_cpu_mem_usage,
            **model_kwargs,
            torch_dtype=utils.get_dtype(dtype),
        )
        # forever after, access self._model through self.model property
        self.model.eval()
        self.model.tie_weights()
        if single_device:
            # place model onto device, if not using HF Accelerate in any form
            self.model.to(self.device)

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
        )

        self.vocab_size = self.tokenizer.vocab_size
        # the EOS token doubles as the padding token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        self._max_length = max_length

        # multithreading and batching
        self.batch_size_per_gpu = batch_size

        # multigpu data-parallel support when launched with accelerate
        if gpus > 1:
            accelerator = Accelerator()
            if parallelize:
                if accelerator.num_processes > 1:
                    raise RuntimeError(
                        "Attempted to use both a HF Accelerate `device_map` and to launch via `accelerate launch`. If this is the case, please either remove `parallelize=True` from --model_args or launch outside of the Accelerate launcher."
                    )
            elif gpus > accelerator.num_processes:
                # TODO: make sure there's still never an edge case where we unintentionally default to CPU
                eval_logger.warning(
                    "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                    "If you would like to use data parallelism, please launch the script "
                    "with 'accelerate launch *script*'. "
                    f"Current run will proceed with {accelerator.num_processes} devices."
                )
                self._rank = accelerator.local_process_index
                self._world_size = accelerator.num_processes
                # manually set model to use gpu, for case where many GPUs available but
                # only seek to use one
                self._device = (
                    torch.device(f"cuda:{accelerator.local_process_index}")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
                self.model.to(self.device)
            else:
                # one process per GPU: let Accelerate wrap the model and keep the
                # accelerator so the `model` property can unwrap it later
                self._model = accelerator.prepare(self.model)
                self._device = torch.device(f"cuda:{accelerator.local_process_index}")
                self.accelerator = accelerator

                if self.accelerator.is_local_main_process:
                    eval_logger.info(f"Using {gpus} devices with data parallelism")

                self._rank = self.accelerator.local_process_index
                self._world_size = self.accelerator.num_processes
haileyschoelkopf's avatar
haileyschoelkopf committed
201

202
203
204
205
206
    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

207
208
209
210
211
212
213
214
    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

215
216
217
218
219
220
221
    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
222
223
224
225
226
227
228
229
230
231
232
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
        for attr in seqlen_config_attrs:
            if hasattr(self.model.config, attr):
                return getattr(self.model.config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH
233

234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def tok_encode(self, string: str, left_truncate_len=None):
        """Encode a single string into a list of token ids.

        Special tokens are added for seq2seq models and omitted for causal
        models.

        :param string: str
            Text to encode.
        :param left_truncate_len: int, optional
            When truthy, only the last `left_truncate_len` tokens are kept.
        :return: the (possibly left-truncated) encoding from `tokenizer.encode`.
        """
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            add_special_tokens = True
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            add_special_tokens = False

        token_ids = self.tokenizer.encode(
            string, add_special_tokens=add_special_tokens
        )

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        return token_ids[-left_truncate_len:] if left_truncate_len else token_ids

haileyschoelkopf's avatar
haileyschoelkopf committed
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
    def tok_batch_encode(
        self, strings: List[str], padding_side="left", left_truncate_len=None
    ):
        """Encode a batch of strings into padded tensors.

        Unlike `tok_encode`, this converts to tensors and pads to the longest
        sequence in the batch. The tokenizer's `padding_side` is switched to
        `padding_side` only for the duration of the call and is restored
        afterwards, even if tokenization raises.

        :param strings: List[str]
            Texts to encode.
        :param padding_side: str
            Padding side to use for this batch.
        :param left_truncate_len: int, optional
            When truthy, only the last `left_truncate_len` columns of the
            input ids and attention mask are kept.
        :return: tuple `(input_ids, attention_mask)`.
        """
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            add_special_tokens = False
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            add_special_tokens = True

        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side
        try:
            encoding = self.tokenizer(
                strings,
                padding="longest",
                return_tensors="pt",
                add_special_tokens=add_special_tokens,
            )
        finally:
            # the tokenizer is shared state: never leave it with a changed padding side
            self.tokenizer.padding_side = old_padding_side

        if left_truncate_len:
            encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
            encoding["attention_mask"] = encoding["attention_mask"][
                :, -left_truncate_len:
            ]

        return encoding["input_ids"], encoding["attention_mask"]

296
297
298
299
300
301
302
303
    def tok_decode(self, tokens):
        """Decode token ids back to text.

        Causal models decode as-is; seq2seq models decode with
        `skip_special_tokens=True`.
        """
        model_class = self.AUTO_MODEL_CLASS
        if model_class == transformers.AutoModelForCausalLM:
            return self.tokenizer.decode(tokens)
        if model_class == transformers.AutoModelForSeq2SeqLM:
            return self.tokenizer.decode(tokens, skip_special_tokens=True)

    def _model_call(self, inps, attn_mask=None, labels=None):
        """
        Run a forward pass without gradients and return the logits.

        :param inps: torch.Tensor
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
            [batch, sequence_ctx]. the size of sequence may vary from call to call
        :param attn_mask: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :param labels: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :return
            A torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model's decoder
        """
        # attn_mask and labels travel together: giving either one selects the seq2seq path
        seq2seq_call = attn_mask is not None or labels is not None
        with torch.no_grad():
            if not seq2seq_call:
                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                return self.model(inps).logits

            assert attn_mask is not None and labels is not None
            assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
            outputs = self.model(
                input_ids=inps, attention_mask=attn_mask, labels=labels
            )
            return outputs.logits

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        """Generate from `context` with the model's `generate` method.

        :param context: tensor of input ids; its first dimension is passed to
            the stopping criteria as the batch size.
        :param max_length: forwarded to `generate` as `max_length`.
        :param stop: stop specification handed to `stop_sequences_criteria`.
        :param generation_kwargs: extra keyword arguments forwarded to
            `generate`; `do_sample` defaults to False when absent.
        :return: whatever `self.model.generate` returns.
        """
        # we require users to pass do_sample=True explicitly
        # for non-greedy gen. This should be reevaluated when considering beam search.
        generation_kwargs.setdefault("do_sample", False)

        # build stopping criteria
        batch_size = context.shape[0]
        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, 1, batch_size
        )

        return self.model.generate(
            context,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.eot_token_id,
            use_cache=True,
            **generation_kwargs,
        )
345
346
347

    def _select_cont_toks(self, logits, contlen=None, inplen=None):
        """Slice out the logits that score the continuation tokens.

        :param logits: logits for one sequence, indexed by position along the
            first dimension.
        :param contlen: continuation length; required for both model types.
        :param inplen: input length; required for causal models and must not
            be given for seq2seq models.
        :return: the sliced logits.
        """
        model_class = self.AUTO_MODEL_CLASS

        if model_class == transformers.AutoModelForCausalLM:
            assert (
                contlen and inplen
            ), "Must pass input len and cont. len to select scored logits for causal LM"
            # discard right-padding.
            # also discard the input/context tokens. we'll only score continuations.
            return logits[inplen - contlen : inplen]

        if model_class == transformers.AutoModelForSeq2SeqLM:
            assert (
                contlen and not inplen
            ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
            # only discard right-padding.
            # the logits input to this fn only contain decoder-side tokens.
            return logits[:contlen]

        return logits

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context)

            continuation_enc = self.tok_encode(continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        """Score whole strings by summing log-likelihoods over rolling windows.

        Each request's `args` is a 1-tuple `(string,)`. The string is encoded,
        split into windows of at most `self.max_length` tokens, every window is
        scored with `_loglikelihood_tokens`, and the per-window log-likelihoods
        are summed.

        :return: list with one summed log-likelihood per request.
        """
        totals = []
        documents = [req.args for req in requests]

        for (string,) in tqdm(documents, disable=(self.rank != 0)):
            windows = utils.get_rolling_token_windows(
                token_list=self.tok_encode(string),
                prefix_token=self.eot_token_id,
                max_seq_len=self.max_length,
                context_len=1,
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            # each window becomes a request triple with `None` in the cache-key slot
            window_requests = [
                (None,) + utils.make_disjoint_window(window) for window in windows
            ]

            n_padding = 0
            if self.world_size > 1:
                # We pad out the external document-level iterator so the inner iterator doesn't hang
                local_count = torch.tensor(len(window_requests), device=self.device)
                all_counts = (
                    self.accelerator.gather(local_count)
                    .cpu()
                    .detach()
                    .numpy()
                    .tolist()
                )

                n_padding = max(all_counts) - all_counts[self.rank]
                if n_padding > 0:
                    window_requests += n_padding * [window_requests[0]]

            scored = self._loglikelihood_tokens(window_requests, disable_tqdm=True)

            # drop the results that belong to the padding requests added above
            if (self.world_size > 1) and (n_padding > 0):
                scored = scored[:-n_padding]

            # keep only the log-likelihood, discarding is_greedy
            totals.append(sum([entry[0] for entry in scored]))

        return totals

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        """Score already-tokenized (context, continuation) requests in batches.

        :param requests: iterable of triples
            `(cache_key, context_enc, continuation_enc)`, the last two being
            lists of token ids.
        :param disable_tqdm: bool
            When true the progress bar is disabled (it is also disabled on
            every rank other than 0).
        :return: list of `(log prob, is-exact-match)` tuples, restored to the
            order of `requests`.
        """
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        results = []

        def _collate(request):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            all_tokens = request[1] + request[2]
            return -len(all_tokens), tuple(all_tokens)

        # TODO: automatic (variable) batch size detection for vectorization
        reorderer = utils.Reorderer(requests, _collate)
        progress = tqdm(
            reorderer.get_reordered(), disable=(disable_tqdm or (self.rank != 0))
        )
        for chunk in utils.chunks(progress, self.batch_size):
            is_causal = self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
            is_seq2seq = self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM

            inps = []
            cont_toks_list = []
            inplens = []

            conts = []
            encoder_attns = []

            padding_len_inp = None
            padding_len_cont = None
            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works (illustrated on a causal decoder-only setup):
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is dropped by the trailing [:-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out later by
                # cont_toks      4 5 6 7 8 9      `_select_cont_toks`

                # when too long to fit in context, truncate from the left
                if is_causal:
                    joined = context_enc + continuation_enc
                    inp = torch.tensor(
                        joined[-(self.max_length + 1) :][:-1],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape
                elif is_seq2seq:
                    inp = torch.tensor(
                        context_enc[-self.max_length :],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape

                    # build encoder attn masks
                    encoder_attns.append(torch.ones_like(inp))

                    cont = torch.tensor(
                        continuation_enc[-self.max_length :],
                        # TODO: left-shift these?
                        # TODO: our code assumes we never end up truncating conts for either model type
                        dtype=torch.long,
                        device=self.device,
                    )
                    (contlen,) = cont.shape

                    conts.append(cont)

                    # running maximum of the continuation lengths in this batch
                    if padding_len_cont is None:
                        padding_len_cont = contlen
                    else:
                        padding_len_cont = max(padding_len_cont, contlen)

                # running maximum of the input lengths in this batch
                if padding_len_inp is None:
                    padding_len_inp = inplen
                else:
                    padding_len_inp = max(padding_len_inp, inplen)

                inps.append(inp)  # [1, inp_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)

            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
            if is_causal:
                batched_inps = utils.pad_and_concat(
                    padding_len_inp, inps, padding_side="right"
                )  # [batch, padding_len_inp]
            elif is_seq2seq:
                # TODO: left-pad encoder inps and mask?
                batched_inps = utils.pad_and_concat(
                    padding_len_inp, inps
                )  # [batch, padding_len_inp]
                batched_conts = utils.pad_and_concat(
                    padding_len_cont, conts
                )  # [batch, padding_len_cont]
                batched_encoder_mask = utils.pad_and_concat(
                    padding_len_inp, encoder_attns
                )  # [batch, padding_len_inp]
                call_kwargs = {
                    "attn_mask": batched_encoder_mask,
                    "labels": batched_conts,
                }

            raw_logits = self._model_call(batched_inps, **call_kwargs)
            multi_logits = F.log_softmax(
                raw_logits, dim=-1
            )  # [batch, padding_length (inp or cont), vocab]

            for (cache_key, _, _), logits, inplen, cont_toks in zip(
                chunk, multi_logits, inplens, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
                # take only logits in the continuation
                # (discard context toks if decoder-only ; discard right-padding)
                ctx_len = inplen if is_causal else None
                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
                logits = logits.unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(
                    cont_toks, dtype=torch.long, device=self.device
                ).unsqueeze(0)  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                    -1
                )  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

                results.append(answer)

                self.cache_hook.add_partial("loglikelihood", cache_key, answer)

        return reorderer.get_original(results)

    def greedy_until(self, requests):
        """Generate continuations for each request until a stop sequence is hit.

        Each request's `args` is a `(context, gen_kwargs)` pair. Requests are
        grouped by `str(gen_kwargs)` so that one batch never mixes different
        generation settings, and within a group they are processed longest
        context first.

        Recognized keys in `gen_kwargs` (all other keys are forwarded to
        `_model_generate`):
          - "until": a stop string or list of stop strings; defaults to the
            decoded end-of-text token when missing or empty.
          - "max_gen_toks": number of tokens to generate; defaults to
            `self.max_gen_toks`.

        :return: the generated strings, restored to the order of `requests`.
        :raises ValueError: if `gen_kwargs` is not a dict, or "until" is
            neither a str nor a list.
        """
        res = defaultdict(list)
        re_ords = {}

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
        for key, reqs in grouper.get_grouped().items():
            # within each set of reqs for given kwargs, we reorder by token length, descending.
            re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)

        pbar = tqdm(total=len(requests), disable=(self.rank != 0))

        # for each different set of kwargs, we execute all requests, by batch.
        for key, re_ord in re_ords.items():
            for chunk in utils.chunks(
                re_ord.get_reordered(),
                self.batch_size,
            ):
                contexts, all_gen_kwargs = zip(*chunk)
                # we assume all gen kwargs in the batch are the same
                # this is safe to assume because the `grouper` object ensures it.
                gen_kwargs = all_gen_kwargs[0]
                if not isinstance(gen_kwargs, dict):
                    # report the offending object itself; no `kwargs` copy exists yet
                    raise ValueError(
                        f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
                    )

                # unpack our keyword arguments.
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                until = None
                if "until" in kwargs:
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        # a single stop string becomes a one-element list
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                        )
                if not until:
                    until = [self.tok_decode(self.eot_token_id)]
                if "max_gen_toks" in kwargs:
                    max_gen_toks = kwargs.pop("max_gen_toks")
                else:
                    max_gen_toks = self.max_gen_toks
                # first stop sequence is used to halt generation upon encountering
                primary_until = until[0]

                # set the max length in tokens of inputs ("context_enc")
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    # max len for inputs = max length, minus room to generate the max new tokens
                    max_ctx_len = self.max_length - max_gen_toks
                elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                    # max len for inputs = encoder's whole max_length
                    max_ctx_len = self.max_length

                # encode, pad, and truncate contexts for this batch
                context_enc, attn_masks = self.tok_batch_encode(
                    contexts, left_truncate_len=max_ctx_len
                )
                context_enc = context_enc.to(self.device)
                attn_masks = attn_masks.to(self.device)

                # perform batched generation
                cont = self._model_generate(
                    context=context_enc,
                    attention_mask=attn_masks,
                    max_length=context_enc.shape[1] + max_gen_toks,
                    stop=primary_until,
                    **kwargs,
                )

                cont_toks_list = cont.tolist()
                for cont_toks, context in zip(cont_toks_list, contexts):
                    # discard context + left-padding toks if using causal decoder-only LM
                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                        cont_toks = cont_toks[context_enc.shape[1] :]

                    s = self.tok_decode(cont_toks)

                    # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                    for term in until:
                        if len(term) > 0:
                            # ignore '' separator,
                            # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                            s = s.split(term)[0]

                    res[key].append(s)

                    self.cache_hook.add_partial(
                        "greedy_until", (context, gen_kwargs), s
                    )
                    pbar.update(1)
            # reorder this group of results back to original unsorted form
            res[key] = re_ord.get_original(res[key])

        pbar.close()

        return grouper.get_original(res)