huggingface.py 34.9 KB
Newer Older
1
2
3
import torch
import transformers
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
4
from peft import __version__ as PEFT_VERSION, PeftModel
5
6

import copy
7
from collections import defaultdict
8
from tqdm import tqdm
9
from pathlib import Path
10
11
12
13
14
15
16
17
18
19

import torch.nn.functional as F

from lm_eval import utils
from lm_eval.logger import eval_logger
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model

from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria

Benjamin Fattori's avatar
Benjamin Fattori committed
20
from accelerate import Accelerator, find_executable_batch_size
21
from typing import List, Optional, Union
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46


def _get_accelerate_args(
    device_map_option: Optional[str] = "auto",
    max_memory_per_gpu: Optional[Union[int, str]] = None,
    max_cpu_memory: Optional[Union[int, str]] = None,
    offload_folder: Optional[str] = "./offload",
) -> dict:
    """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
    max_memory = {}
    if max_memory_per_gpu is not None:
        max_memory_per_gpu_map = {
            device_idx: max_memory_per_gpu
            for device_idx in range(torch.cuda.device_count())
        }
        max_memory.update(max_memory_per_gpu_map)
    if max_cpu_memory is not None:
        max_memory["cpu"] = max_cpu_memory

    args = {}
    if max_memory:
        args["max_memory"] = max_memory
    args["device_map"] = device_map_option
    args["offload_folder"] = offload_folder
    return args
47
48


49
@register_model("hf-auto", "hf", "huggingface")
50
class HFLM(LM):
51
52
53
54
55
56
57
    """
    An abstracted Huggingface model class. Enables usage with both models of
    `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.

    Supports data-parallel multi-GPU with HF Accelerate.
    """

58
    AUTO_MODEL_CLASS = None
59
    _DEFAULT_MAX_LENGTH = 2048
haileyschoelkopf's avatar
haileyschoelkopf committed
60

61
62
    def __init__(
        self,
        pretrained: Optional[str] = "gpt2",
        revision: Optional[str] = "main",
        subfolder: Optional[str] = None,
        tokenizer: Optional[str] = None,
        max_length: Optional[int] = None,
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        batch_size: Optional[Union[int, str]] = 1,
        max_batch_size: Optional[int] = 64,
        low_cpu_mem_usage: Optional[bool] = True,
        trust_remote_code: Optional[bool] = False,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        parallelize: Optional[bool] = False,
        device_map_option: Optional[str] = "auto",
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[str] = "./offload",
        # PEFT and quantization options
        peft: Optional[str] = None,
        load_in_8bit: Optional[bool] = False,
        load_in_4bit: Optional[bool] = False,
        bnb_4bit_quant_type: Optional[str] = None,
        bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
        gptq: Optional[Union[bool, str]] = False,
        gptq_use_triton: Optional[bool] = False,
    ):
        """Load config, model and tokenizer, then place the model on its device(s).

        :param pretrained: model identifier passed to the `from_pretrained` /
            `from_quantized` loaders.
        :param revision: model revision; `subfolder`, when given, is appended
            to it as "<revision>/<subfolder>".
        :param subfolder: optional subfolder, see `revision`.
        :param tokenizer: tokenizer identifier; falls back to `pretrained`.
        :param max_length: manual override returned by the `max_length` property.
        :param device: device string used when neither `parallelize` nor a
            multi-process Accelerate launch is active. A value outside
            "cuda"/"cpu"/"cuda:<i>" is converted with `int()`.
        :param dtype: converted with `utils.get_dtype` and passed as `torch_dtype`.
        :param batch_size: an int, or "auto" / "auto:<N>" where <N> is stored
            as `batch_schedule`.
        :param max_batch_size: stored as `self.max_batch_size`.
        :param low_cpu_mem_usage: forwarded to the model loader.
        :param trust_remote_code: forwarded to config, model and tokenizer loaders.
        :param parallelize: when True, the `device_map_option`,
            `max_memory_per_gpu`, `max_cpu_memory` and `offload_folder`
            arguments are turned into loader kwargs via `_get_accelerate_args`.
        :param peft: optional adapter identifier loaded with `PeftModel.from_pretrained`.
        :param load_in_8bit: forwarded to the model loader (non-GPTQ path).
        :param load_in_4bit: forwarded to the model loader when transformers
            is at least 4.30.0; asserts that version otherwise.
        :param bnb_4bit_quant_type: forwarded only when `load_in_4bit` is set.
        :param bnb_4bit_compute_dtype: forwarded (via `utils.get_dtype`) only
            when `load_in_4bit` is set.
        :param gptq: falsy for a regular load; True or a checkpoint file name
            to load through `AutoGPTQForCausalLM.from_quantized`.
        :param gptq_use_triton: forwarded as `use_triton` and `warmup_triton`.
        :raises AssertionError: on wrong argument types or unmet version requirements.
        :raises Exception: if `gptq` is requested but `auto_gptq` is not installed.
        :raises RuntimeError: if `parallelize=True` is combined with a
            multi-process `accelerate launch`.
        """
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, (int, str))

        gpus = torch.cuda.device_count()
        accelerator = Accelerator()

        if not (parallelize or accelerator.num_processes > 1):
            # use user-passed device
            device_list = set(
                ["cuda", "cpu"]
                + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
            )
            if device:
                if device not in device_list:
                    # anything else is interpreted as a bare device index;
                    # int() raises ValueError for non-numeric strings.
                    device = int(device)
                self._device = torch.device(device)
                eval_logger.info(f"Using device '{device}'")
            else:
                eval_logger.info("Device not specified")
                eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                self._device = (
                    torch.device("cuda")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
        else:
            if device != "cuda":
                eval_logger.info(
                    f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
                )
            # TODO: include in warning that `load_in_8bit` etc. affect this too
            self._device = device

        model_kwargs = {}
        if parallelize:
            model_kwargs = _get_accelerate_args(
                device_map_option,
                max_memory_per_gpu,
                max_cpu_memory,
                offload_folder,
            )

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        revision = revision + ("/" + subfolder if subfolder is not None else "")

        self._config = transformers.AutoConfig.from_pretrained(
            pretrained,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        # Pick the auto-class from the config's model type: causal LM if it is
        # a known causal architecture, seq2seq otherwise.
        if getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
            self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
        else:
            self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM

        assert self.AUTO_MODEL_CLASS in [
            transformers.AutoModelForCausalLM,
            transformers.AutoModelForSeq2SeqLM,
        ]

        if not gptq:
            # Versions are compared numerically; comparing the raw strings
            # would e.g. rank "4.4.0" above "4.30.0".
            transformers_supports_4bit = self._version_at_least(
                transformers.__version__, "4.30.0"
            )
            if load_in_4bit:
                assert (
                    transformers_supports_4bit
                ), "load_in_4bit requires transformers >= 4.30.0"
            if transformers_supports_4bit:
                model_kwargs["load_in_4bit"] = load_in_4bit
                if load_in_4bit:
                    if bnb_4bit_quant_type:
                        model_kwargs["bnb_4bit_quant_type"] = bnb_4bit_quant_type
                    if bnb_4bit_compute_dtype:
                        model_kwargs["bnb_4bit_compute_dtype"] = utils.get_dtype(
                            bnb_4bit_compute_dtype
                        )
            self._model = self.AUTO_MODEL_CLASS.from_pretrained(
                pretrained,
                revision=revision,
                torch_dtype=utils.get_dtype(dtype),
                low_cpu_mem_usage=low_cpu_mem_usage,
                trust_remote_code=trust_remote_code,
                load_in_8bit=load_in_8bit,
                **model_kwargs,
            )
        else:
            try:
                from auto_gptq import AutoGPTQForCausalLM
            except ModuleNotFoundError as err:
                # Single message string (previously two positional args, which
                # rendered as a tuple), chained to the import failure.
                raise Exception(
                    "Tried to load auto_gptq, but auto-gptq is not installed, "
                    "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]"
                ) from err

            self._model = AutoGPTQForCausalLM.from_quantized(
                pretrained,
                model_basename=None if gptq is True else Path(gptq).stem,
                low_cpu_mem_usage=low_cpu_mem_usage,
                trust_remote_code=trust_remote_code,
                use_safetensors=True if gptq is True else gptq.endswith(".safetensors"),
                use_triton=gptq_use_triton,
                warmup_triton=gptq_use_triton,
                **model_kwargs,
            )

        if peft:
            if load_in_4bit:
                assert self._version_at_least(
                    PEFT_VERSION, "0.4.0"
                ), "load_in_4bit requires peft >= 0.4.0"
            self._model = PeftModel.from_pretrained(
                self._model, peft, revision=revision
            )

        # forever after, access self._model through self.model property
        self.model.eval()
        self.model.tie_weights()
        if gpus <= 1 and not parallelize:
            # place model onto device, if not using HF Accelerate in any form
            try:
                self.model.to(self.device)
            except ValueError:
                eval_logger.info(
                    "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
                )

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        self.vocab_size = self.tokenizer.vocab_size
        # padding reuses the EOS token id
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        self._max_length = max_length

        self.batch_schedule = 1
        self.batch_sizes = {}
        self.max_batch_size = max_batch_size

        if str(batch_size).startswith("auto"):
            # "auto" or "auto:<N>": keep the "auto" marker, parse the schedule
            batch_size = batch_size.split(":")
            self.batch_size_per_gpu = batch_size[0]
            self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
        else:
            self.batch_size_per_gpu = int(batch_size)

        # multigpu data-parallel support when launched with accelerate
        if gpus > 1:
            if parallelize:
                if accelerator.num_processes > 1:
                    raise RuntimeError(
                        "Attempted to use both a HF Accelerate `device_map` and to launch via `accelerate launch`. If this is the case, please either remove `parallelize=True` from --model_args or launch outside of the Accelerate launcher."
                    )
                else:
                    pass
            elif gpus > accelerator.num_processes:
                # TODO: make sure there's still never an edge case where we unintentionally default to CPU
                eval_logger.warning(
                    "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                    "If you would like to use data parallelism, please launch the script "
                    "with 'accelerate launch *script*'. "
                    f"Current run will proceed with {accelerator.num_processes} devices."
                )
                self._rank = accelerator.local_process_index
                self._world_size = accelerator.num_processes
                # manually set model to use gpu, for case where many GPUs available but
                # only seek to use one
                self._device = (
                    torch.device(f"cuda:{accelerator.local_process_index}")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
                try:
                    self.model.to(self.device)
                except ValueError:
                    eval_logger.info(
                        "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
                    )
            else:
                self._model = accelerator.prepare(self.model)
                self._device = torch.device(f"cuda:{accelerator.local_process_index}")
                self.accelerator = accelerator

                if self.accelerator.is_local_main_process:
                    eval_logger.info(f"Using {gpus} devices with data parallelism")

                self._rank = self.accelerator.local_process_index
                self._world_size = self.accelerator.num_processes

    @staticmethod
    def _version_at_least(installed, required):
        """Return True if version string `installed` is numerically >= `required`.

        Only the leading numeric dotted components are compared, so
        "4.31.0.dev0" is read as 4.31.0; missing components count as 0.
        """

        def _release(version):
            # Collect the leading integer of each dotted piece, stopping at the
            # first piece that does not start with a digit (e.g. "dev0").
            numbers = []
            for piece in str(version).split("."):
                digits = ""
                for ch in piece:
                    if not ("0" <= ch <= "9"):
                        break
                    digits += ch
                if not digits:
                    break
                numbers.append(int(digits))
            return numbers

        have, need = _release(installed), _release(required)
        width = max(len(have), len(need))
        have += [0] * (width - len(have))
        need += [0] * (width - len(need))
        return have >= need
haileyschoelkopf's avatar
haileyschoelkopf committed
280

281
282
283
284
285
    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

286
287
288
289
290
291
292
293
    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

294
295
296
297
298
299
300
    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
301
302
303
304
305
306
307
308
309
310
311
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
        for attr in seqlen_config_attrs:
            if hasattr(self.model.config, attr):
                return getattr(self.model.config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH
312

313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

Benjamin Fattori's avatar
Benjamin Fattori committed
333
334
335
336
337
338
339
340
    def _detect_batch_size(self, requests=None, pos=0):
        """Find a batch size that runs without error for the probed sequence length.

        :param requests: optional list of (key, context_enc, continuation_enc)
            tuples; when given, the probe length is taken from `requests[pos]`
            after the same left-truncation used for scoring.
        :param pos: index into `requests` of the sample to size the probe on.
        :return: the batch size at which the probe forward passes succeeded.
        """
        if requests:
            _, context_enc, continuation_enc = requests[pos]
            full_sequence = context_enc + continuation_enc
            # keep at most max_length + 1 trailing tokens, then drop the last one
            probe_len = len(full_sequence[-(self.max_length + 1) :][:-1])
        else:
            probe_len = self.max_length

        # Starts at max_batch_size; per the original note, on OOM the decorator
        # halves batch_size and tries again.
        @find_executable_batch_size(starting_batch_size=self.max_batch_size)
        def forward_batch(batch_size):
            dummy_inputs = torch.ones(
                (batch_size, probe_len), device=self.device
            ).long()
            for _ in range(5):
                _ = F.log_softmax(self._model_call(dummy_inputs), dim=-1).cpu()
            return batch_size

        detected = forward_batch()
        utils.clear_torch_cache()

        return detected


356
    def tok_encode(self, string: str, left_truncate_len=None):
        """Encode `string` to a list of token ids.

        Special tokens are added only for seq2seq models, not for causal LMs.
        When `left_truncate_len` is truthy, only that many trailing tokens are kept.
        """
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            add_special_tokens = False
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            add_special_tokens = True

        token_ids = self.tokenizer.encode(
            string, add_special_tokens=add_special_tokens
        )

        if not left_truncate_len:
            return token_ids
        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        return token_ids[-left_truncate_len:]

haileyschoelkopf's avatar
haileyschoelkopf committed
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
    def tok_batch_encode(
        self, strings: List[str], padding_side="left", left_truncate_len=None
    ):
        """Encode a batch of strings into padded tensors.

        Unlike `tok_encode`, this converts to tensors and pads automatically.

        :param strings: the strings to encode.
        :param padding_side: padding side set on the tokenizer for this call
            only; the previous value is restored afterwards, even on error.
        :param left_truncate_len: when truthy, only that many trailing columns
            of the ids and attention mask are kept.
        :return: tuple `(input_ids, attention_mask)`.
        """
        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side
        try:
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                add_special_tokens = False
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                add_special_tokens = True

            encoding = self.tokenizer(
                strings,
                padding="longest",
                return_tensors="pt",
                add_special_tokens=add_special_tokens,
            )
            if left_truncate_len:
                encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
                encoding["attention_mask"] = encoding["attention_mask"][
                    :, -left_truncate_len:
                ]
        finally:
            # The tokenizer is shared state: never leave it with the temporary
            # padding side if encoding raises.
            self.tokenizer.padding_side = old_padding_side

        return encoding["input_ids"], encoding["attention_mask"]

398
399
400
401
402
403
404
405
    def tok_decode(self, tokens):
        """Decode token ids to text; special tokens are skipped for seq2seq models only."""
        model_class = self.AUTO_MODEL_CLASS
        if model_class == transformers.AutoModelForCausalLM:
            return self.tokenizer.decode(tokens)
        if model_class == transformers.AutoModelForSeq2SeqLM:
            return self.tokenizer.decode(tokens, skip_special_tokens=True)
        return None

    def _model_call(self, inps, attn_mask=None, labels=None):
        """
        Run a gradient-free forward pass and return the logits.

        :param inps: torch.Tensor
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
            [batch, sequence_ctx]. the size of sequence may vary from call to call
        :param attn_mask: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :param labels: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :return
            A torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model's decoder
        """
        with torch.no_grad():
            if attn_mask is None and labels is None:
                # plain call: only valid for causal LMs
                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                return self.model(inps).logits

            # seq2seq call: mask and labels must be supplied together
            assert attn_mask is not None and labels is not None
            assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
            outputs = self.model(
                input_ids=inps, attention_mask=attn_mask, labels=labels
            )
            return outputs.logits

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        """Generate from `context` until `max_length` or a stop sequence is hit.

        Extra keyword arguments are forwarded to `self.model.generate`.
        """
        # we require users to pass do_sample=True explicitly
        # for non-greedy gen. This should be reevaluated when considering beam search.
        generation_kwargs.setdefault("do_sample", False)

        # build stopping criteria
        batch_rows = context.shape[0]
        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, 1, batch_rows
        )
        return self.model.generate(
            context,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.eot_token_id,
            use_cache=True,
            **generation_kwargs,
        )
447
448
449

    def _select_cont_toks(self, logits, contlen=None, inplen=None):
        """Slice `logits` down to the positions that score the continuation.

        Causal LMs need both `contlen` and `inplen`; seq2seq models need only
        `contlen` (and `inplen` must be falsy).
        """
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            assert (
                contlen and inplen
            ), "Must pass input len and cont. len to select scored logits for causal LM"
            # discard right-padding.
            # also discard the input/context tokens. we'll only score continuations.
            return logits[inplen - contlen : inplen]

        if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            assert (
                contlen and not inplen
            ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
            # only discard right-padding.
            # the logits input to this fn only contain decoder-side tokens.
            return logits[:contlen]

        return logits

466
467
468
469
470
471
472
473
474
475
476
    def _encode_pair(self, context, continuation):
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]
        whole_enc = self.tok_encode(context + continuation)
        context_enc = self.tok_encode(context)
        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]
        return context_enc, continuation_enc

477
478
479
480
481
    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # end of text as context
482
483
484
                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
                    continuation
                )
485
            else:
486
                context_enc, continuation_enc = self._encode_pair(context, continuation)
487
488
489
490
491
492
493

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        """Score each request string with rolling windows and sum the window scores.

        Each string is tokenized, split into windows of at most `max_length`
        tokens (prefixed with the end-of-text token), scored through
        `_loglikelihood_tokens`, and the per-window log-likelihoods are summed.

        :return: list with one summed log-likelihood per request.
        """
        loglikelihoods = []

        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
            token_windows = utils.get_rolling_token_windows(
                token_list=self.tok_encode(string),
                prefix_token=self.eot_token_id,
                max_seq_len=self.max_length,
                context_len=1,
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            # Each window gets a leading None in the cache-key slot.
            rolling_token_windows = [
                (None,) + utils.make_disjoint_window(window)
                for window in token_windows
            ]

            pad_amnt = 0
            if self.world_size > 1:
                # We pad out the external document-level iterator so the inner iterator doesn't hang
                window_count = torch.tensor(
                    len(rolling_token_windows), device=self.device
                )
                gathered = (
                    self.accelerator.gather(window_count).cpu().detach().numpy().tolist()
                )

                pad_amnt = max(gathered) - gathered[self.rank]
                if pad_amnt > 0:
                    rolling_token_windows += pad_amnt * [rolling_token_windows[0]]

            window_results = self._loglikelihood_tokens(
                rolling_token_windows, disable_tqdm=True, override_bs=adaptive_batch_size
            )

            # pad_amnt is only ever non-zero in the multi-process branch above,
            # so checking it alone identifies the padded results to drop.
            if pad_amnt > 0:
                window_results = window_results[:-pad_amnt]

            # keep the log-likelihood, discard is_greedy
            string_nll = sum([result[0] for result in window_results])
            loglikelihoods.append(string_nll)

        return loglikelihoods

Benjamin Fattori's avatar
Benjamin Fattori committed
546
    def _loglikelihood_tokens(self, requests, disable_tqdm=False, override_bs=None):
547
548
549
550
551
552
553
554
555
556
557
558
559
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return -len(toks), tuple(toks)
Benjamin Fattori's avatar
Benjamin Fattori committed
560
    
561
        re_ord = utils.Reorderer(requests, _collate)
Benjamin Fattori's avatar
Benjamin Fattori committed
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576

        n_reordered_requests = len(re_ord.get_reordered())
        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
        def _batch_scheduler(pos):
            sched = pos // int(n_reordered_requests / self.batch_schedule)
            if sched in self.batch_sizes:
                return self.batch_sizes[sched]
            print(
                f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
            )
            self.batch_sizes[sched] = self._detect_batch_size(re_ord.get_reordered(), pos)
            print(f"Determined largest batch size: {self.batch_sizes[sched]}")
            return self.batch_sizes[sched]    

577
578
        for chunk in utils.chunks(
            tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
Benjamin Fattori's avatar
Benjamin Fattori committed
579
580
581
582
583
584
585
586
            n=self.batch_size
            if self.batch_size != "auto"
            else override_bs
            if override_bs is not None
            else 0,
            fn=_batch_scheduler
            if self.batch_size == "auto" and n_reordered_requests > 0 and not override_bs
            else None,
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
        ):
            inps = []
            cont_toks_list = []
            inplens = []

            conts = []
            encoder_attns = []

            padding_len_inp = None
            padding_len_cont = None
            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

haileyschoelkopf's avatar
haileyschoelkopf committed
607
                # how this all works (illustrated on a causal decoder-only setup):
608
609
610
611
612
613
614
615
616
617
618
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    inp = torch.tensor(
                        (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                        dtype=torch.long,
619
620
                        device=self.device,
                    )
621
622
623
624
625
                    (inplen,) = inp.shape
                elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                    inp = torch.tensor(
                        (context_enc)[-self.max_length :],
                        dtype=torch.long,
haileyschoelkopf's avatar
haileyschoelkopf committed
626
                        device=self.device,
627
                    )
628
                    (inplen,) = inp.shape
629
630
631
632

                    # build encoder attn masks
                    encoder_attns.append(torch.ones_like(inp))

633
                    cont = torch.tensor(
haileyschoelkopf's avatar
haileyschoelkopf committed
634
                        (continuation_enc)[-self.max_length :],
635
636
                        # TODO: left-shift these?
                        # TODO: our code assumes we never end up truncating conts for either model type
637
                        dtype=torch.long,
638
639
                        device=self.device,
                    )
640
641
                    (contlen,) = cont.shape

642
643
                    conts.append(cont)

haileyschoelkopf's avatar
haileyschoelkopf committed
644
645
646
647
648
                    padding_len_cont = (
                        max(padding_len_cont, contlen)
                        if padding_len_cont is not None
                        else contlen
                    )
649

haileyschoelkopf's avatar
haileyschoelkopf committed
650
651
652
653
654
                padding_len_inp = (
                    max(padding_len_inp, inplen)
                    if padding_len_inp is not None
                    else inplen
                )
655
656
657
658

                inps.append(inp)  # [1, inp_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)
haileyschoelkopf's avatar
haileyschoelkopf committed
659

660
661
662
            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
haileyschoelkopf's avatar
haileyschoelkopf committed
663
664
665
                batched_inps = utils.pad_and_concat(
                    padding_len_inp, inps, padding_side="right"
                )  # [batch, padding_len_inp]
666
667
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # TODO: left-pad encoder inps and mask?
haileyschoelkopf's avatar
haileyschoelkopf committed
668
669
670
671
672
673
674
675
676
677
678
679
680
                batched_inps = utils.pad_and_concat(
                    padding_len_inp, inps
                )  # [batch, padding_len_inp]
                batched_conts = utils.pad_and_concat(
                    padding_len_cont, conts
                )  # [batch, padding_len_cont]
                batched_encoder_mask = utils.pad_and_concat(
                    padding_len_inp, encoder_attns
                )  # [batch, padding_len_inp]
                call_kwargs = {
                    "attn_mask": batched_encoder_mask,
                    "labels": batched_conts,
                }
681
682
683

            multi_logits = F.log_softmax(
                self._model_call(batched_inps, **call_kwargs), dim=-1
684
            )  # [batch, padding_length (inp or cont), vocab]
685
686
687
688
689
690

            for (cache_key, _, _), logits, inplen, cont_toks in zip(
                chunk, multi_logits, inplens, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
haileyschoelkopf's avatar
haileyschoelkopf committed
691
                # take only logits in the continuation
692
                # (discard context toks if decoder-only ; discard right-padding)
haileyschoelkopf's avatar
haileyschoelkopf committed
693
694
695
696
697
                ctx_len = (
                    inplen
                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                    else None
                )
698
                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
haileyschoelkopf's avatar
haileyschoelkopf committed
699
                logits = logits.unsqueeze(0)  # [1, seq, vocab]
700
701
702

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
703
704
705
                cont_toks = torch.tensor(
                    cont_toks, dtype=torch.long, device=self.device
                ).unsqueeze(
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
                    0
                )  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                    -1
                )  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

                res.append(answer)

haileyschoelkopf's avatar
haileyschoelkopf committed
721
722
                self.cache_hook.add_partial("loglikelihood", cache_key, answer)

723
724
725
        return re_ord.get_original(res)

    def greedy_until(self, requests):
        """Generate a continuation for each request and cut it at its stop sequences.

        Every request exposes `args == (context, gen_kwargs)`. `gen_kwargs` must be a
        dict; the keys `until` (a `str` or a `list`) and `max_gen_toks` are consumed
        here, and every remaining key is forwarded unchanged to `self._model_generate`.
        If no stop sequence is supplied, the decoded `self.eot_token_id` is used; if
        `max_gen_toks` is absent, `self.max_gen_toks` is used.

        Requests are grouped by `str(gen_kwargs)` so that a batch never mixes
        different generation settings, and each group is processed longest-context
        first in chunks of `self.batch_size`.

        :param requests: request objects whose `.args` is `(context, gen_kwargs)`.
        :return: the per-group results passed through `grouper.get_original`.
        :raises ValueError: if `gen_kwargs` is not a dict, or if `gen_kwargs["until"]`
            is neither a `str` nor a `list`.
        """
        res = defaultdict(list)
        re_ords = {}

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
        for key, reqs in grouper.get_grouped().items():
            # within each set of reqs for given kwargs, we reorder by token length, descending.
            re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)

        pbar = tqdm(total=len(requests), disable=(self.rank != 0))

        # for each different set of kwargs, we execute all requests, by batch.
        for key, re_ord in re_ords.items():
            for chunk in utils.chunks(
                re_ord.get_reordered(),
                self.batch_size,
            ):
                contexts, all_gen_kwargs = zip(*chunk)
                # we assume all gen kwargs in the batch are the same
                # this is safe to assume because the `grouper` object ensures it.
                gen_kwargs = all_gen_kwargs[0]

                if not isinstance(gen_kwargs, dict):
                    # report the offending object itself; `kwargs` is only bound
                    # once a valid dict has been seen.
                    raise ValueError(
                        f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
                    )

                # unpack our keyword arguments.
                # copy so that popping keys never mutates the request's own dict
                # (edge case for repeats > 1, where requests share one dict).
                kwargs = copy.deepcopy(gen_kwargs)

                until = None
                if "until" in kwargs:
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        # normalise a single stop string to a one-element list
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                        )
                if not until:
                    until = [self.tok_decode(self.eot_token_id)]

                max_gen_toks = kwargs.pop("max_gen_toks", self.max_gen_toks)

                # first stop sequence is used to halt generation upon encountering
                primary_until = until[0]

                # set the max length in tokens of inputs ("context_enc")
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    # max len for inputs = max length, minus room to generate the max new tokens
                    max_ctx_len = self.max_length - max_gen_toks
                elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                    # max len for inputs = encoder's whole max_length
                    max_ctx_len = self.max_length

                # encode, pad, and truncate contexts for this batch
                context_enc, attn_masks = self.tok_batch_encode(
                    contexts, left_truncate_len=max_ctx_len
                )
                context_enc = context_enc.to(self.device)
                attn_masks = attn_masks.to(self.device)

                # perform batched generation
                cont = self._model_generate(
                    context=context_enc,
                    attention_mask=attn_masks,
                    max_length=context_enc.shape[1] + max_gen_toks,
                    stop=primary_until,
                    **kwargs,
                )

                cont_toks_list = cont.tolist()
                for cont_toks, context in zip(cont_toks_list, contexts):
                    # discard context + left-padding toks if using causal decoder-only LM
                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                        cont_toks = cont_toks[context_enc.shape[1] :]

                    s = self.tok_decode(cont_toks)

                    # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                    for term in until:
                        if len(term) > 0:
                            # ignore '' separator,
                            # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                            s = s.split(term)[0]

                    res[key].append(s)

                    self.cache_hook.add_partial(
                        "greedy_until", (context, gen_kwargs), s
                    )
                    pbar.update(1)
            # reorder this group of results back to original unsorted form
            res[key] = re_ord.get_original(res[key])

        pbar.close()

        return grouper.get_original(res)