huggingface.py 35.6 KB
Newer Older
1
2
3
import torch
import transformers
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
4
from peft import __version__ as PEFT_VERSION, PeftModel
5
6

import copy
7
from collections import defaultdict
8
from tqdm import tqdm
9
from pathlib import Path
10
11
12
13
14
15
16
17
18
19

import torch.nn.functional as F

from lm_eval import utils
from lm_eval.logger import eval_logger
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model

from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria

Benjamin Fattori's avatar
Benjamin Fattori committed
20
from accelerate import Accelerator, find_executable_batch_size
21
from typing import List, Optional, Union
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46


def _get_accelerate_args(
    device_map_option: Optional[str] = "auto",
    max_memory_per_gpu: Optional[Union[int, str]] = None,
    max_cpu_memory: Optional[Union[int, str]] = None,
    offload_folder: Optional[str] = "./offload",
) -> dict:
    """Assemble the `accelerate` keyword arguments for `AutoModel.from_pretrained`.

    :param device_map_option: value stored under the `device_map` key.
    :param max_memory_per_gpu: if given, the same limit is recorded for every
        CUDA device index reported by `torch.cuda.device_count()`.
    :param max_cpu_memory: if given, recorded under the `"cpu"` key of the
        memory map.
    :param offload_folder: value stored under the `offload_folder` key.
    :return: dict with `device_map` and `offload_folder`, plus `max_memory`
        when at least one memory limit ended up in the map.
    """
    memory_limits = {}
    if max_memory_per_gpu is not None:
        # one identical entry per visible CUDA device, keyed by device index
        memory_limits = dict.fromkeys(
            range(torch.cuda.device_count()), max_memory_per_gpu
        )
    if max_cpu_memory is not None:
        memory_limits["cpu"] = max_cpu_memory

    # `max_memory` is omitted entirely when no limit was recorded
    accelerate_kwargs = {"max_memory": memory_limits} if memory_limits else {}
    accelerate_kwargs["device_map"] = device_map_option
    accelerate_kwargs["offload_folder"] = offload_folder
    return accelerate_kwargs
47
48


49
@register_model("hf-auto", "hf", "huggingface")
50
class HFLM(LM):
51
52
53
54
55
56
57
    """
    An abstracted Huggingface model class. Enables usage with both models of
    `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.

    Supports data-parallel multi-GPU with HF Accelerate.
    """

58
    AUTO_MODEL_CLASS = None
59
    _DEFAULT_MAX_LENGTH = 2048
haileyschoelkopf's avatar
haileyschoelkopf committed
60

61
62
    def __init__(
        self,
        pretrained: Optional[str] = "gpt2",
        revision: Optional[str] = "main",
        subfolder: Optional[str] = None,
        tokenizer: Optional[str] = None,
        max_length: Optional[int] = None,
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        batch_size: Optional[Union[int, str]] = 1,
        max_batch_size: Optional[int] = 64,
        low_cpu_mem_usage: Optional[bool] = True,
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        parallelize: Optional[bool] = False,
        device_map_option: Optional[str] = "auto",
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[str] = "./offload",
        # PEFT and quantization options
        peft: Optional[str] = None,
        load_in_8bit: Optional[bool] = False,
        load_in_4bit: Optional[bool] = False,
        bnb_4bit_quant_type: Optional[str] = None,
        bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
        gptq: Optional[Union[bool, str]] = False,
        gptq_use_triton: Optional[bool] = False,
    ):
        """Load the model and tokenizer and decide where the model lives.

        :param pretrained: model identifier handed to the config, model and
            (unless `tokenizer` is given) tokenizer loaders.
        :param revision: revision handed to the loaders; when `subfolder` is
            given it is appended as `"<revision>/<subfolder>"`.
        :param subfolder: optional subfolder appended to `revision`.
        :param tokenizer: optional tokenizer identifier used instead of
            `pretrained` when loading the tokenizer.
        :param max_length: stored as `self._max_length`.
        :param device: device string used when neither `parallelize` nor a
            multi-process Accelerate launch is active. A value outside
            `cuda`/`cpu`/`mps`/`cuda:<i>` is converted with `int()`.
        :param dtype: passed through `utils.get_dtype` as `torch_dtype` when
            loading a non-GPTQ model. Forced to `"float32"` when `device` is
            `"mps"`.
        :param batch_size: an int, or a string starting with `"auto"`
            (optionally `"auto:<schedule>"`).
        :param max_batch_size: stored as `self.max_batch_size`.
        :param low_cpu_mem_usage: forwarded to the model loader.
        :param trust_remote_code: forwarded to config, model and tokenizer loaders.
        :param use_fast_tokenizer: forwarded to the tokenizer loader as `use_fast`.
        :param parallelize: when true, `_get_accelerate_args` builds extra
            loader kwargs from `device_map_option`, `max_memory_per_gpu`,
            `max_cpu_memory` and `offload_folder`.
        :param peft: optional adapter identifier loaded via `PeftModel.from_pretrained`.
        :param load_in_8bit: forwarded to the non-GPTQ model loader.
        :param load_in_4bit: forwarded to the non-GPTQ model loader when the
            installed `transformers` is at least 4.30.0.
        :param bnb_4bit_quant_type: forwarded when `load_in_4bit` is set.
        :param bnb_4bit_compute_dtype: forwarded (via `utils.get_dtype`) when
            `load_in_4bit` is set.
        :param gptq: falsy to load through `transformers`; `True` or a
            checkpoint file name to load through `auto_gptq`.
        :param gptq_use_triton: forwarded to `auto_gptq` as `use_triton` and
            `warmup_triton`.
        :raises AssertionError: on wrongly typed `device`/`pretrained`/
            `batch_size`, or when `load_in_4bit` is requested with a too-old
            `transformers`/`peft`.
        :raises Exception: when `gptq` is set but `auto_gptq` is not installed.
        :raises RuntimeError: when `parallelize=True` is combined with a
            multi-process Accelerate launch on a multi-GPU machine.
        """
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, (int, str))

        def _version_at_least(installed, minimum):
            """Return True if the numeric release part of `installed` >= `minimum`.

            Version strings must not be compared lexicographically: as plain
            strings "4.4.0" >= "4.30.0" is True and "0.10.0" >= "0.4.0" is False.
            Only the leading digits of each dotted component are considered,
            so suffixes such as "dev0" or "rc1" are ignored.
            """
            release = []
            for piece in installed.split(".")[: len(minimum)]:
                digits = ""
                for char in piece:
                    if not "0" <= char <= "9":
                        break
                    digits += char
                release.append(int(digits) if digits else 0)
            # pad short versions ("4.30" -> (4, 30, 0)) before comparing
            release += [0] * (len(minimum) - len(release))
            return tuple(release) >= tuple(minimum)

        gpus = torch.cuda.device_count()
        accelerator = Accelerator()

        if not (parallelize or accelerator.num_processes > 1):
            # use user-passed device
            device_list = set(
                ["cuda", "cpu", "mps"]
                + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
            )
            if device:
                if device not in device_list:
                    device = int(device)
                self._device = torch.device(device)
                eval_logger.info(f"Using device '{device}'")
                if device == "mps":
                    eval_logger.info(
                        "MPS is still in beta and only supports float32; setting dtype to float32."
                    )
                    # actually apply what the message above announces
                    dtype = "float32"
            else:
                eval_logger.info("Device not specified")
                eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                self._device = (
                    torch.device("cuda")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
        else:
            if device != "cuda":
                eval_logger.info(
                    f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
                )
            # TODO: include in warning that `load_in_8bit` etc. affect this too
            self._device = device

        model_kwargs = {}
        if parallelize:
            model_kwargs = _get_accelerate_args(
                device_map_option,
                max_memory_per_gpu,
                max_cpu_memory,
                offload_folder,
            )

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        revision = revision + ("/" + subfolder if subfolder is not None else "")

        self._config = transformers.AutoConfig.from_pretrained(
            pretrained,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        # pick the auto-class from the config's model type
        if getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
            self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
        else:
            self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM

        assert self.AUTO_MODEL_CLASS in [
            transformers.AutoModelForCausalLM,
            transformers.AutoModelForSeq2SeqLM,
        ]

        if not gptq:
            supports_4bit = _version_at_least(transformers.__version__, (4, 30, 0))
            if load_in_4bit:
                assert supports_4bit, "load_in_4bit requires transformers >= 4.30.0"
            if supports_4bit:
                model_kwargs["load_in_4bit"] = load_in_4bit
                if load_in_4bit:
                    if bnb_4bit_quant_type:
                        model_kwargs["bnb_4bit_quant_type"] = bnb_4bit_quant_type
                    if bnb_4bit_compute_dtype:
                        model_kwargs["bnb_4bit_compute_dtype"] = utils.get_dtype(
                            bnb_4bit_compute_dtype
                        )
            self._model = self.AUTO_MODEL_CLASS.from_pretrained(
                pretrained,
                revision=revision,
                torch_dtype=utils.get_dtype(dtype),
                low_cpu_mem_usage=low_cpu_mem_usage,
                trust_remote_code=trust_remote_code,
                load_in_8bit=load_in_8bit,
                **model_kwargs,
            )
        else:
            try:
                from auto_gptq import AutoGPTQForCausalLM
            except ModuleNotFoundError as err:
                # a single concatenated message (the two pieces used to be
                # passed as separate arguments, yielding a tuple message)
                raise Exception(
                    "Tried to load auto_gptq, but auto-gptq is not installed "
                    "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]"
                ) from err

            self._model = AutoGPTQForCausalLM.from_quantized(
                pretrained,
                model_basename=None if gptq is True else Path(gptq).stem,
                low_cpu_mem_usage=low_cpu_mem_usage,
                trust_remote_code=trust_remote_code,
                use_safetensors=True if gptq is True else gptq.endswith(".safetensors"),
                use_triton=gptq_use_triton,
                warmup_triton=gptq_use_triton,
                **model_kwargs,
            )

        if peft:
            if load_in_4bit:
                assert _version_at_least(
                    PEFT_VERSION, (0, 4, 0)
                ), "load_in_4bit requires peft >= 0.4.0"
            self._model = PeftModel.from_pretrained(
                self._model, peft, revision=revision
            )

        # forever after, access self._model through self.model property
        self.model.eval()
        self.model.tie_weights()
        if gpus <= 1 and not parallelize:
            # place model onto device, if not using HF Accelerate in any form
            try:
                self.model.to(self.device)
            except ValueError:
                eval_logger.info(
                    "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
                )

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
            trust_remote_code=trust_remote_code,
            use_fast=use_fast_tokenizer,
        )

        self.vocab_size = self.tokenizer.vocab_size
        # padding reuses the end-of-sequence token id
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        self._max_length = max_length

        self.batch_schedule = 1
        self.batch_sizes = {}
        self.max_batch_size = max_batch_size

        if str(batch_size).startswith("auto"):
            # "auto" or "auto:<schedule>"
            batch_size = batch_size.split(":")
            self.batch_size_per_gpu = batch_size[0]
            self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
        else:
            self.batch_size_per_gpu = int(batch_size)

        # multigpu data-parallel support when launched with accelerate
        if gpus > 1:
            if parallelize:
                if accelerator.num_processes > 1:
                    raise RuntimeError(
                        "Attempted to use both a HF Accelerate `device_map` and to launch via `accelerate launch`. If this is the case, please either remove `parallelize=True` from --model_args or launch outside of the Accelerate launcher."
                    )
                else:
                    pass
            elif gpus > accelerator.num_processes:
                # TODO: make sure there's still never an edge case where we unintentionally default to CPU
                eval_logger.warning(
                    "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                    "If you would like to use data parallelism, please launch the script "
                    "with 'accelerate launch *script*'. "
                    f"Current run will proceed with {accelerator.num_processes} devices."
                )
                self._rank = accelerator.local_process_index
                self._world_size = accelerator.num_processes
                # manually set model to use gpu, for case where many GPUs available but
                # only seek to use one
                self._device = (
                    torch.device(f"cuda:{accelerator.local_process_index}")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
                try:
                    self.model.to(self.device)
                except ValueError:
                    eval_logger.info(
                        "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
                    )
            else:
                self._model = accelerator.prepare(self.model)
                self._device = torch.device(f"cuda:{accelerator.local_process_index}")
                self.accelerator = accelerator

                if self.accelerator.is_local_main_process:
                    eval_logger.info(f"Using {gpus} devices with data parallelism")

                self._rank = self.accelerator.local_process_index
                self._world_size = self.accelerator.num_processes
haileyschoelkopf's avatar
haileyschoelkopf committed
286

287
288
289
290
291
    @property
    def config(self):
        """Return the `transformers.AutoConfig` loaded for the pretrained model."""
        return self._config

292
293
294
295
296
297
298
299
    @property
    def model(self):
        """Return the model, unwrapped by Accelerate when an accelerator is attached."""
        try:
            wrapper = self.accelerator
        except AttributeError:
            # no accelerator was attached: the stored model is used as-is
            return self._model
        return wrapper.unwrap_model(self._model)

300
301
302
303
304
305
306
    @property
    def eot_token_id(self):
        """Return the tokenizer's EOS token id, used here as the end-of-*text* marker.

        "End of text" describes its role in this class better than "end of sentence".
        """
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Return the maximum sequence length to use.

        Resolution order: a truthy manually-set `self._max_length`; the first
        of `n_positions`, `max_position_embeddings`, `n_ctx` present on the
        model config; the tokenizer's `model_max_length` (unless it equals the
        huge placeholder value below); finally `self._DEFAULT_MAX_LENGTH`.
        """
        manual_limit = self._max_length
        if manual_limit:  # if max length manually set, return it
            return manual_limit

        model_config = self.model.config
        for attr_name in ("n_positions", "max_position_embeddings", "n_ctx"):
            if hasattr(model_config, attr_name):
                return getattr(model_config, attr_name)

        if not hasattr(self.tokenizer, "model_max_length"):
            return self._DEFAULT_MAX_LENGTH

        tokenizer_limit = self.tokenizer.model_max_length
        # equals int(1e30); presumably the tokenizer's "no limit set" placeholder
        # -- verify against the tokenizer library
        if tokenizer_limit == 1000000000000000019884624838656:
            return self._DEFAULT_MAX_LENGTH
        return tokenizer_limit
318

319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
    @property
    def max_gen_toks(self):
        """Return the fixed generation budget of 256 (presumably counted in tokens)."""
        return 256

    @property
    def batch_size(self):
        """Return the per-GPU batch size stored in `self.batch_size_per_gpu`."""
        return self.batch_size_per_gpu

    @property
    def device(self):
        """Return the device stored in `self._device`."""
        return self._device

    @property
    def rank(self):
        """Return this process's rank as stored in `self._rank`."""
        return self._rank

    @property
    def world_size(self):
        """Return the number of processes as stored in `self._world_size`."""
        return self._world_size

Benjamin Fattori's avatar
Benjamin Fattori committed
339
340
341
342
343
344
345
346
    def _detect_batch_size(self, requests=None, pos=0):
        """Find a batch size that can run a forward pass without failing.

        :param requests: optional sequence of `(_, context_enc, continuation_enc)`
            triples; when given, the probe length is taken from `requests[pos]`.
        :param pos: index into `requests` of the sample that sets the probe length.
        :return: the batch size that worked; with more than one process, the
            minimum over the values gathered from all processes.
        """
        length_cap = self.max_length
        if requests:
            _, ctx_tokens, cont_tokens = requests[pos]
            # same left-truncation and final-token drop as applied to real inputs
            probe_len = len((ctx_tokens + cont_tokens)[-(length_cap + 1) :][:-1])
        else:
            probe_len = length_cap

        # if OOM, then halves batch_size and tries again
        @find_executable_batch_size(starting_batch_size=self.max_batch_size)
        def forward_batch(batch_size):
            dummy = torch.ones((batch_size, probe_len), device=self.device).long()
            for _ in range(5):
                _ = F.log_softmax(self._model_call(dummy), dim=-1).cpu()
            return batch_size

        detected = forward_batch()

        if self.world_size > 1:
            # if multi-GPU, always take minimum over all selected batch sizes
            local_choice = torch.tensor([detected], device=self.device)
            all_choices = (
                self.accelerator.gather(local_choice).cpu().detach().numpy().tolist()
            )
            detected = min(all_choices)

        utils.clear_torch_cache()
        return detected


373
    def tok_encode(self, string: str, left_truncate_len=None):
        """Encode `string` into token ids with the instance tokenizer.

        Special tokens are added for seq2seq models and omitted for causal
        models. When `left_truncate_len` is truthy, only the last
        `left_truncate_len` ids are kept.
        """
        model_class = self.AUTO_MODEL_CLASS
        if model_class == transformers.AutoModelForCausalLM:
            with_special = False
        elif model_class == transformers.AutoModelForSeq2SeqLM:
            with_special = True

        token_ids = self.tokenizer.encode(string, add_special_tokens=with_special)

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        return token_ids[-left_truncate_len:] if left_truncate_len else token_ids

haileyschoelkopf's avatar
haileyschoelkopf committed
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
    def tok_batch_encode(
        self, strings: List[str], padding_side="left", left_truncate_len=None
    ):
        """Encode a batch of strings into padded tensors.

        Unlike `tok_encode`, this converts to tensors and pads to the longest
        item automatically.

        :param strings: the texts to encode.
        :param padding_side: padding side to set on the tokenizer for the
            duration of this call; the previous value is restored afterwards,
            even if tokenization raises.
        :param left_truncate_len: when truthy, only the last
            `left_truncate_len` columns of ids and mask are kept.
        :return: tuple `(input_ids, attention_mask)`.
        """
        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side
        try:
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                add_special_tokens = False
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                add_special_tokens = True

            encoding = self.tokenizer(
                strings,
                padding="longest",
                return_tensors="pt",
                add_special_tokens=add_special_tokens,
            )
            if left_truncate_len:
                encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
                encoding["attention_mask"] = encoding["attention_mask"][
                    :, -left_truncate_len:
                ]
        finally:
            # never leave the shared tokenizer with a changed padding side
            self.tokenizer.padding_side = old_padding_side

        return encoding["input_ids"], encoding["attention_mask"]

415
416
417
418
419
420
421
422
    def tok_decode(self, tokens):
        """Decode token ids to text; special tokens are skipped for seq2seq models.

        Returns None when the model class is neither causal nor seq2seq.
        """
        model_class = self.AUTO_MODEL_CLASS
        if model_class == transformers.AutoModelForCausalLM:
            return self.tokenizer.decode(tokens)
        if model_class == transformers.AutoModelForSeq2SeqLM:
            return self.tokenizer.decode(tokens, skip_special_tokens=True)
        return None

    def _model_call(self, inps, attn_mask=None, labels=None):
        """
        Run the model under `torch.no_grad()` and return its logits.

        :param inps: torch.Tensor
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
            [batch, sequence_ctx]. the size of sequence may vary from call to call
        :param attn_mask: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :param labels: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :return
            A torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model's decoder
        """
        with torch.no_grad():
            if attn_mask is None and labels is None:
                # plain causal call: only the input ids are passed
                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                return self.model(inps).logits

            # seq2seq call: mask and labels must be supplied together
            assert attn_mask is not None and labels is not None
            assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
            output = self.model(
                input_ids=inps, attention_mask=attn_mask, labels=labels
            )
            return output.logits

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        """Call `self.model.generate` on `context` with stop-sequence criteria.

        :param context: input passed to `generate`; `context.shape[0]` is
            handed to `stop_sequences_criteria` together with `stop`.
        :param max_length: forwarded to `generate` as `max_length`.
        :param stop: stop sequences used to build the stopping criteria.
        :param generation_kwargs: extra keyword arguments for `generate`;
            `do_sample` defaults to False when not supplied.
        """
        # we require users to pass do_sample=True explicitly
        # for non-greedy gen. This should be reevaluated when considering beam search.
        generation_kwargs.setdefault("do_sample", False)

        # build stopping criteria
        batch_dim = context.shape[0]
        criteria = stop_sequences_criteria(self.tokenizer, stop, 1, batch_dim)

        return self.model.generate(
            context,
            max_length=max_length,
            stopping_criteria=criteria,
            pad_token_id=self.eot_token_id,
            use_cache=True,
            **generation_kwargs,
        )
464
465
466

    def _select_cont_toks(self, logits, contlen=None, inplen=None):
        """Slice `logits` down to the positions that score the continuation.

        Causal models need both `contlen` and `inplen`; seq2seq models need
        only `contlen`. The slice is taken along the first axis of `logits`.
        For any other model class `logits` is returned unchanged.
        """
        model_class = self.AUTO_MODEL_CLASS

        if model_class == transformers.AutoModelForCausalLM:
            assert (
                contlen and inplen
            ), "Must pass input len and cont. len to select scored logits for causal LM"
            # discard right-padding.
            # also discard the input/context tokens. we'll only score continuations.
            first_kept = inplen - contlen
            return logits[first_kept:inplen]

        if model_class == transformers.AutoModelForSeq2SeqLM:
            assert (
                contlen and not inplen
            ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
            # only discard right-padding.
            # the logits input to this fn only contain decoder-side tokens.
            return logits[:contlen]

        return logits

483
484
485
486
487
488
489
490
491
492
493
    def _encode_pair(self, context, continuation):
        """Encode a (context, continuation) pair into two token-id lists.

        Trailing whitespace of `context` is moved to the front of
        `continuation` first. The continuation ids are whatever remains of
        the encoding of the concatenated text after the context's ids.
        """
        trimmed = context.rstrip()
        n_trailing = len(context) - len(trimmed)
        if n_trailing > 0:
            continuation = context[-n_trailing:] + continuation
            context = trimmed

        full_ids = self.tok_encode(context + continuation)
        context_ids = self.tok_encode(context)
        return context_ids, full_ids[len(context_ids) :]

494
495
496
497
498
    def loglikelihood(self, requests):
        """Tokenize each request's (context, continuation) pair and score them.

        An empty context is replaced by the single end-of-text token id.
        The tokenized requests are handed to `_loglikelihood_tokens`, whose
        result is returned.
        """
        pairs = [req.args for req in requests]
        tokenized = []
        for context, continuation in pairs:
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
                continuation_enc = self.tok_encode(continuation)
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)
            tokenized.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(tokenized)

    def loglikelihood_rolling(self, requests):
        """Score each request's full string with rolling token windows.

        Each string is tokenized, split into windows by
        `utils.get_rolling_token_windows` / `utils.make_disjoint_window`, and
        scored by `_loglikelihood_tokens`; the first element of every
        per-window result is summed into one value per request.

        :return: list with one summed value per request, in request order.
        """
        detected_bs = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            largest = self._detect_batch_size()
            print(f"Determined Largest batch size: {largest}")
            detected_bs = largest

        totals = []
        documents = [req.args for req in requests]
        for (text,) in tqdm(documents, disable=(self.rank != 0)):
            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            windows = [
                (None,) + utils.make_disjoint_window(window)
                for window in utils.get_rolling_token_windows(
                    token_list=self.tok_encode(text),
                    prefix_token=self.eot_token_id,
                    max_seq_len=self.max_length,
                    context_len=1,
                )
            ]

            n_padding = 0
            if self.world_size > 1:
                # We pad out the external document-level iterator so the inner iterator doesn't hang
                local_count = torch.tensor(len(windows), device=self.device)
                all_counts = (
                    self.accelerator.gather(local_count).cpu().detach().numpy().tolist()
                )

                n_padding = max(all_counts) - all_counts[self.rank]
                if n_padding > 0:
                    windows += n_padding * [windows[0]]

            scored = self._loglikelihood_tokens(
                windows, disable_tqdm=True, override_bs=detected_bs
            )

            # padding is only ever added in the multi-process case; drop its results
            if n_padding > 0:
                scored = scored[:-n_padding]

            # discard is_greedy
            totals.append(sum(entry[0] for entry in scored))

        return totals

Benjamin Fattori's avatar
Benjamin Fattori committed
563
    def _loglikelihood_tokens(self, requests, disable_tqdm=False, override_bs=None):
564
565
566
567
568
569
570
571
572
573
574
575
576
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return -len(toks), tuple(toks)
Benjamin Fattori's avatar
Benjamin Fattori committed
577
    
578
        re_ord = utils.Reorderer(requests, _collate)
Benjamin Fattori's avatar
Benjamin Fattori committed
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593

        n_reordered_requests = len(re_ord.get_reordered())
        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
        def _batch_scheduler(pos):
            sched = pos // int(n_reordered_requests / self.batch_schedule)
            if sched in self.batch_sizes:
                return self.batch_sizes[sched]
            print(
                f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
            )
            self.batch_sizes[sched] = self._detect_batch_size(re_ord.get_reordered(), pos)
            print(f"Determined largest batch size: {self.batch_sizes[sched]}")
            return self.batch_sizes[sched]    

594
595
        for chunk in utils.chunks(
            tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
Benjamin Fattori's avatar
Benjamin Fattori committed
596
597
598
599
600
601
602
603
            n=self.batch_size
            if self.batch_size != "auto"
            else override_bs
            if override_bs is not None
            else 0,
            fn=_batch_scheduler
            if self.batch_size == "auto" and n_reordered_requests > 0 and not override_bs
            else None,
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
        ):
            inps = []
            cont_toks_list = []
            inplens = []

            conts = []
            encoder_attns = []

            padding_len_inp = None
            padding_len_cont = None
            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

haileyschoelkopf's avatar
haileyschoelkopf committed
624
                # how this all works (illustrated on a causal decoder-only setup):
625
626
627
628
629
630
631
632
633
634
635
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    inp = torch.tensor(
                        (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                        dtype=torch.long,
636
637
                        device=self.device,
                    )
638
639
640
641
642
                    (inplen,) = inp.shape
                elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                    inp = torch.tensor(
                        (context_enc)[-self.max_length :],
                        dtype=torch.long,
haileyschoelkopf's avatar
haileyschoelkopf committed
643
                        device=self.device,
644
                    )
645
                    (inplen,) = inp.shape
646
647
648
649

                    # build encoder attn masks
                    encoder_attns.append(torch.ones_like(inp))

650
                    cont = torch.tensor(
haileyschoelkopf's avatar
haileyschoelkopf committed
651
                        (continuation_enc)[-self.max_length :],
652
653
                        # TODO: left-shift these?
                        # TODO: our code assumes we never end up truncating conts for either model type
654
                        dtype=torch.long,
655
656
                        device=self.device,
                    )
657
658
                    (contlen,) = cont.shape

659
660
                    conts.append(cont)

haileyschoelkopf's avatar
haileyschoelkopf committed
661
662
663
664
665
                    padding_len_cont = (
                        max(padding_len_cont, contlen)
                        if padding_len_cont is not None
                        else contlen
                    )
666

haileyschoelkopf's avatar
haileyschoelkopf committed
667
668
669
670
671
                padding_len_inp = (
                    max(padding_len_inp, inplen)
                    if padding_len_inp is not None
                    else inplen
                )
672
673
674
675

                inps.append(inp)  # [1, inp_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)
haileyschoelkopf's avatar
haileyschoelkopf committed
676

677
678
679
            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
haileyschoelkopf's avatar
haileyschoelkopf committed
680
681
682
                batched_inps = utils.pad_and_concat(
                    padding_len_inp, inps, padding_side="right"
                )  # [batch, padding_len_inp]
683
684
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # TODO: left-pad encoder inps and mask?
haileyschoelkopf's avatar
haileyschoelkopf committed
685
686
687
688
689
690
691
692
693
694
695
696
697
                batched_inps = utils.pad_and_concat(
                    padding_len_inp, inps
                )  # [batch, padding_len_inp]
                batched_conts = utils.pad_and_concat(
                    padding_len_cont, conts
                )  # [batch, padding_len_cont]
                batched_encoder_mask = utils.pad_and_concat(
                    padding_len_inp, encoder_attns
                )  # [batch, padding_len_inp]
                call_kwargs = {
                    "attn_mask": batched_encoder_mask,
                    "labels": batched_conts,
                }
698
699
700

            multi_logits = F.log_softmax(
                self._model_call(batched_inps, **call_kwargs), dim=-1
701
            )  # [batch, padding_length (inp or cont), vocab]
702
703
704
705
706
707

            for (cache_key, _, _), logits, inplen, cont_toks in zip(
                chunk, multi_logits, inplens, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
haileyschoelkopf's avatar
haileyschoelkopf committed
708
                # take only logits in the continuation
709
                # (discard context toks if decoder-only ; discard right-padding)
haileyschoelkopf's avatar
haileyschoelkopf committed
710
711
712
713
714
                ctx_len = (
                    inplen
                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                    else None
                )
715
                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
haileyschoelkopf's avatar
haileyschoelkopf committed
716
                logits = logits.unsqueeze(0)  # [1, seq, vocab]
717
718
719

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
720
721
722
                cont_toks = torch.tensor(
                    cont_toks, dtype=torch.long, device=self.device
                ).unsqueeze(
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
                    0
                )  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                    -1
                )  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

                res.append(answer)

haileyschoelkopf's avatar
haileyschoelkopf committed
738
739
                self.cache_hook.add_partial("loglikelihood", cache_key, answer)

740
741
742
        return re_ord.get_original(res)

    def greedy_until(self, requests):
        """Generate a continuation for each request, cut off at its stop sequences.

        Each request exposes ``.args`` as a ``(context, gen_kwargs)`` pair.
        Requests are grouped by ``str(gen_kwargs)`` so that a batch never mixes
        different generation settings; inside a group they are processed in
        order of descending tokenized context length.

        Recognised keys in ``gen_kwargs`` (all other keys are forwarded
        unchanged to ``self._model_generate``):
            until: a stop string or a list of stop strings. The first entry is
                passed to generation as ``stop``; every non-empty entry is used
                to trim the decoded text afterwards. When missing or empty, the
                decoded ``self.eot_token_id`` is used instead.
            max_gen_toks: number of new tokens to allow; defaults to
                ``self.max_gen_toks``.

        Returns:
            The generated strings as returned by ``grouper.get_original`` after
            each group has been restored via ``Reorderer.get_original``
            (presumably the original request order - confirm against
            ``utils.Grouper``).

        Raises:
            ValueError: if ``gen_kwargs`` is not a dict, or if ``until`` is
                neither a ``str`` nor a ``list``.
        """
        res = defaultdict(list)
        re_ords = {}

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
        for key, reqs in grouper.get_grouped().items():
            # within each set of reqs for given kwargs, we reorder by token length, descending.
            re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)

        pbar = tqdm(total=len(requests), disable=(self.rank != 0))

        # for each different set of kwargs, we execute all requests, by batch.
        for key, re_ord in re_ords.items():
            for chunk in utils.chunks(re_ord.get_reordered(), self.batch_size):
                contexts, all_gen_kwargs = zip(*chunk)
                # we assume all gen kwargs in the batch are the same
                # this is safe to assume because the `grouper` object ensures it.
                gen_kwargs = all_gen_kwargs[0]
                if not isinstance(gen_kwargs, dict):
                    # report the offending object itself; the working copy
                    # `kwargs` does not exist yet on this path.
                    raise ValueError(
                        f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
                    )

                # unpack our keyword arguments.
                # work on a copy: keys are popped below, and the same dict may
                # be shared by other requests (edge case for repeats > 1).
                kwargs = copy.deepcopy(gen_kwargs)
                until = None
                if "until" in kwargs:
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        # normalise a single stop string to a one-element list
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                        )
                if not until:
                    until = [self.tok_decode(self.eot_token_id)]
                if "max_gen_toks" in kwargs:
                    max_gen_toks = kwargs.pop("max_gen_toks")
                else:
                    max_gen_toks = self.max_gen_toks
                # first stop sequence is used to halt generation upon encountering
                primary_until = until[0]

                # set the max length in tokens of inputs ("context_enc")
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    # max len for inputs = max length, minus room to generate the max new tokens
                    max_ctx_len = self.max_length - max_gen_toks
                elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                    # max len for inputs = encoder's whole max_length
                    max_ctx_len = self.max_length

                # encode, pad, and truncate contexts for this batch
                context_enc, attn_masks = self.tok_batch_encode(
                    contexts, left_truncate_len=max_ctx_len
                )
                context_enc = context_enc.to(self.device)
                attn_masks = attn_masks.to(self.device)

                # perform batched generation
                cont = self._model_generate(
                    context=context_enc,
                    attention_mask=attn_masks,
                    max_length=context_enc.shape[1] + max_gen_toks,
                    stop=primary_until,
                    **kwargs,
                )

                cont_toks_list = cont.tolist()
                for cont_toks, context in zip(cont_toks_list, contexts):
                    # discard context + left-padding toks if using causal decoder-only LM
                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                        cont_toks = cont_toks[context_enc.shape[1] :]

                    s = self.tok_decode(cont_toks)

                    # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                    for term in until:
                        if len(term) > 0:
                            # ignore '' separator,
                            # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                            s = s.split(term)[0]

                    res[key].append(s)

                    self.cache_hook.add_partial(
                        "greedy_until", (context, gen_kwargs), s
                    )
                    pbar.update(1)
            # reorder this group of results back to original unsorted form
            res[key] = re_ord.get_original(res[key])

        pbar.close()

        return grouper.get_original(res)