import abc
from typing import Iterable
import numpy as np
import random
import re
import os
import json
import hashlib
import datasets
from tqdm import tqdm
import torch
import torch.nn.functional as F
from accelerate import find_executable_batch_size

from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
from lm_eval import utils
from abc import abstractmethod


class LM(abc.ABC):
    def __init__(self):
        self.cache_hook = CacheHook(None)

    @abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass
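    # Hedged sketch (added for illustration, not part of the original file):
    # how a concrete subclass is typically queried through `loglikelihood`.
    # `MyLM` is a hypothetical subclass used only as an example; the request is
    # a list of (context, continuation) pairs and the result is a list of
    # (logprob, isgreedy) pairs in the same order.
    #
    #     lm = MyLM()
    #     results = lm.loglikelihood([("The capital of France is", " Paris")])
    #     logprob, is_greedy = results[0]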

    @abstractmethod
    def loglikelihood_rolling(self, requests):
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list
            A list of strings
            string: str
                String for which we are computing per-token loglikelihood
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass
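    # Hedged sketch (added for illustration, not part of the original file):
    # the window splitting described above is what `BaseLM.loglikelihood_rolling`
    # does further down in this module with the helpers from `lm_eval.utils`.
    # With max_seq_len=4, context_len=1 and a prefix token standing in for EOT,
    # the ten tokens 0..9 produce the three input/prediction windows shown in
    # the example above.
    #
    #     windows = list(
    #         map(
    #             utils.make_disjoint_window,
    #             utils.get_rolling_token_windows(
    #                 token_list=list(range(10)),
    #                 prefix_token=-1,  # stand-in for the EOT id
    #                 max_seq_len=4,
    #                 context_len=1,
    #             ),
    #         )
    #     )
    #     # each element is an (input_tokens, pred_tokens) pair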

    # TODO: Add an optional max length
    @abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of strings, continuation
            continuation: str
                The generated continuation.
        """
        pass

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config=None):
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)
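    # Hedged usage sketch (not part of the original file): `MySubclassLM` is a
    # hypothetical LM subclass. `create_from_arg_string` parses the
    # comma-separated key=value pairs with `utils.simple_parse_args_string` and
    # forwards them, together with any non-None entries of `additional_config`,
    # as keyword arguments to the constructor.
    #
    #     lm = MySubclassLM.create_from_arg_string(
    #         "pretrained=gpt2,batch_size=8", {"device": "cuda:0"}
    #     )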

    def set_cache_hook(self, cache_hook):
        self.cache_hook = cache_hook


class BaseLM(LM):
    def __init__(self):
        super().__init__()
        self.batch_schedule = 1
        self.batch_sizes = {}
        self.max_batch_size = 512

    @property
    @abstractmethod
    def eot_token_id(self):
        pass

    @property
    @abstractmethod
    def max_length(self):
        pass

    @property
    @abstractmethod
    def max_gen_toks(self):
        pass

    @property
    @abstractmethod
    def batch_size(self):
        pass

    @property
    @abstractmethod
    def device(self):
        pass

    @abstractmethod
    def tok_encode(self, string: str):
        pass

    @abstractmethod
    def tok_decode(self, tokens: Iterable[int]):
        pass

    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id):
        pass

    @abstractmethod
    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        pass

    def _detect_batch_size(self, requests=None, pos=0):
        if requests:
            _, context_enc, continuation_enc = requests[pos]
            max_length = len(
                (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
            )
        else:
            max_length = self.max_length

        # if OOM, then halves batch_size and tries again
        @find_executable_batch_size(starting_batch_size=self.max_batch_size)
        def forward_batch(batch_size):
            test_batch = torch.ones((batch_size, max_length), device=self.device).long()
            for _ in range(5):
                _ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
            return batch_size

        batch_size = forward_batch()
        utils.clear_torch_cache()

        return batch_size

    # subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
    # TODO: enforce this somehow

    def _encode_pair(self, context, continuation):
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]
        whole_enc = self.tok_encode(context + continuation)
        context_enc = self.tok_encode(context)
        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]
        return context_enc, continuation_enc
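    # Hedged illustration (not part of the original file): trailing whitespace is
    # moved from the context onto the continuation before encoding, so the split
    # between context and continuation tokens matches how the tokenizer would
    # merge a leading space with the following word. The strings are made up.
    #
    #     context, continuation = "Question: 2+2=", " 4"
    #     # a context of "Question: 2+2= " with continuation "4" is rewritten to
    #     # the pair above; continuation_enc is then the tail of
    #     # tok_encode(context + continuation) after len(tok_encode(context)) tokens.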

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
                    continuation
                )
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        # TODO: Implement caching once we've confirmed the perplexity implementation

        # automatic batch size detection for vectorization
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        loglikelihoods = []
        for (string,) in tqdm(requests):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows,
                disable_tqdm=True,
                override_bs=adaptive_batch_size,
            )

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False, override_bs=None):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        reordered_requests = re_ord.get_reordered()
        n_reordered_requests = len(reordered_requests)

        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
        def _batch_scheduler(pos):
            sched = pos // int(n_reordered_requests / self.batch_schedule)
            if sched in self.batch_sizes:
                return self.batch_sizes[sched]
            print(
                f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
            )
            self.batch_sizes[sched] = self._detect_batch_size(reordered_requests, pos)
            print(f"Determined largest batch size: {self.batch_sizes[sched]}")
            return self.batch_sizes[sched]

        for chunk in utils.chunks(
            tqdm(reordered_requests, disable=disable_tqdm),
            n=self.batch_size
            if self.batch_size != "auto"
            else override_bs
            if override_bs is not None
            else 0,
            fn=_batch_scheduler
            if self.batch_size == "auto"
            and n_reordered_requests > 0
            and not override_bs
            else None,
        ):
            inps = []
            cont_toks_list = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # gpt2    \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                    dtype=torch.long,
                ).to(self.device)
                (inplen,) = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = (
                    padding_length if padding_length is not None else inplen
                )

                # pad length from seq to padding_length
                inp = torch.cat(
                    [
                        inp,  # [seq]
                        torch.zeros(padding_length - inplen, dtype=torch.long).to(
                            inp.device
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                )

                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(cont)
                inplens.append(inplen)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(
                self._model_call(batched_inps), dim=-1
            ).cpu()  # [batch, padding_length, vocab]

            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
                chunk, multi_logits, inps, inplens, cont_toks_list
            ):

                # Slice to original seq length
                contlen = len(cont_toks)
                inplen = inplen + (
                    logits.shape[0] - padding_length
                )  # if "virtual tokens" (from prompt tuning) are added, inplen is larger
                logits = logits[inplen - contlen : inplen].unsqueeze(
                    0
                )  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
                    0
                )  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                    -1
                )  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return re_ord.get_original(res)

    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles until that are
        #       multiple tokens or that span multiple tokens correctly

        # TODO: extract to TokenizedLM?
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        warn_stop_seq = False
        for context, request_args in tqdm(re_ord.get_reordered()):
            until = request_args["until"]
            if isinstance(until, str):
                until = [until]

            if until:
                try:
                    (primary_until,) = self.tok_encode(until[0])
                except ValueError:
                    if not warn_stop_seq:
                        print(
                            "Warning: a primary stop sequence is multi-token! Will default to EOS token for this tokenizer. Consider using `hf-causal-experimental` for multi-token stop sequence support for the time being."
                        )
                        warn_stop_seq = True
                    primary_until = self.eot_token_id
            else:
                primary_until = None

            context_enc = torch.tensor(
                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
            ).to(self.device)

            max_gen_tokens = min(
                self.max_gen_toks, request_args.get("max_length", self.max_gen_toks)
            )
            cont = self._model_generate(
                context_enc, context_enc.shape[1] + max_gen_tokens, primary_until
            )

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])

            for term in until:
                s = s.split(term)[0]

            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)

            res.append(s)

        return re_ord.get_original(res)
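    # Hedged sketch (not part of the original file): each request handled above
    # pairs a context string with a `request_args` dict carrying an "until"
    # entry (a stop string or list of stop strings) and, optionally, a
    # "max_length" cap. The values below are illustrative only.
    #
    #     requests = [("Question: 2+2=\nAnswer:", {"until": ["\n"], "max_length": 32})]
    #     continuations = lm.greedy_until(requests)  # one decoded string per request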


class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary e.g.
        {"question": ..., "answer": ...} or
        (question, answer)
    """

    # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
    # or a path to a custom `datasets` loading script.
    DATASET_PATH: str = None

    # The name of a subset within `DATASET_PATH`.
    DATASET_NAME: str = None

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None):
        """
        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.download(data_dir, cache_dir, download_mode)
        self._training_docs = None
        self._fewshot_docs = None

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Downloads and returns the task dataset.
        Override this method to download the dataset from a custom API.

        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
        )

    def should_decontaminate(self):
        """Whether this task supports decontamination against model training set."""
        return False

    @abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def test_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def _process_doc(self, doc):
        """
        Override this to process (detokenize, strip, replace, etc.) individual
        documents. This can be used in a map over documents of a data split.
        E.g. `map(self._process_doc, self.dataset["validation"])`

        :return: dict
            The processed version of the specified `doc`.
        """
        return doc

    def fewshot_examples(self, k, rnd):
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    def doc_to_decontamination_query(self, doc):
        print(
            "Override doc_to_decontamination_query with document specific decontamination query."
        )
        assert False

    @abstractmethod
    def doc_to_text(self, doc):
        pass

    @abstractmethod
    def doc_to_target(self, doc):
        pass

    @abstractmethod
    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        pass

    @abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results, and evaluate them, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass

    @abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        pass

    def fewshot_description(self):
        import warnings

        warnings.warn(
            "`fewshot_description` will be removed in future versions. Pass "
            "any custom descriptions to the `evaluate` function instead.",
            DeprecationWarning,
        )
        return ""

    @utils.positional_deprecated
    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param provide_description: bool
            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
        :param rnd: random.Random
            The pseudo-random number generator used to randomly sample examples.
            WARNING: This is currently a required arg although it's optionalized with a default `None`.
        :param description: str
            The task's description that will be prepended to the fewshot examples.
        :returns: str
            The fewshot context.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        description = description + "\n\n" if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs()
                        if self.has_validation_docs()
                        else self.test_docs()
                    )

                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = (
                "\n\n".join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + "\n\n"
            )

        example = self.doc_to_text(doc)
        return description + labeled_examples + example
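    # Hedged illustration (not part of the original file) of the string built
    # above for num_fewshot=2 with a description; the placeholders stand for the
    # task's own doc_to_text/doc_to_target output:
    #
    #     <description>\n\n
    #     <doc_to_text(ex1)><doc_to_target(ex1)>\n\n
    #     <doc_to_text(ex2)><doc_to_target(ex2)>\n\n
    #     <doc_to_text(doc)>    <- the model continues from here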


class MultipleChoiceTask(Task):
    def doc_to_target(self, doc):
        return " " + doc["choices"][doc["gold"]]

    def construct_requests(self, doc, ctx):
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
        ]

        return lls

    def process_results(self, doc, results):
        gold = doc["gold"]

        acc = 1.0 if np.argmax(results) == gold else 0.0
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }
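    # Hedged note (not part of the original file): `acc_norm` length-normalizes
    # the summed loglikelihoods by the character length of each choice before
    # taking the argmax, so longer answer strings are not penalized merely for
    # containing more characters. Illustrative numbers:
    #
    #     results = [-4.0, -6.0]          # summed loglikelihood per choice
    #     completion_len = [2.0, 30.0]    # character length of each choice
    #     # raw argmax picks choice 0; normalized (-2.0 vs -0.2) picks choice 1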

    def higher_is_better(self):
        return {
            "acc": True,
            "acc_norm": True,
        }

    def aggregation(self):
        return {
            "acc": mean,
            "acc_norm": mean,
        }


class PerplexityTask(Task, abc.ABC):
    def should_decontaminate(self):
        """Whether this task supports decontamination against model training set."""
        return True

    def has_training_docs(self):
        return False

    def fewshot_examples(self, k, rnd):
        assert k == 0
        return []

    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        assert (
            num_fewshot == 0
        ), "The number of fewshot examples must be 0 for perplexity tasks."
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`."
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        return ""

    def higher_is_better(self):
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_decontamination_query(self, doc):
        return doc

    def doc_to_text(self, doc):
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc, ctx):
        assert not ctx
        req = rf.loglikelihood_rolling(self.doc_to_target(doc))
        return req

    def process_results(self, doc, results):
        (loglikelihood,) = results
        words = self.count_words(doc)
        bytes_ = self.count_bytes(doc)
        return {
            "word_perplexity": (loglikelihood, words),
            "byte_perplexity": (loglikelihood, bytes_),
            "bits_per_byte": (loglikelihood, bytes_),
        }

    def aggregation(self):
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": bits_per_byte,
        }
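    # Hedged note (not part of the original file): with the metric helpers
    # imported at the top of this module, the aggregation above is expected to
    # reduce the per-document (loglikelihood, weight) pairs roughly as
    #
    #     word_perplexity = exp(-sum(loglikelihoods) / sum(word_counts))
    #     byte_perplexity = exp(-sum(loglikelihoods) / sum(byte_counts))
    #     bits_per_byte   = -sum(loglikelihoods) / (sum(byte_counts) * ln(2))
    #
    # i.e. perplexity weighted by word/byte counts and the byte-level
    # cross-entropy converted from nats to bits.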

    @classmethod
    def count_bytes(cls, doc):
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc):
        """Downstream tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))


def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()


class CacheHook:
    def __init__(self, cachinglm):
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res):
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db):
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        from sqlitedict import SqliteDict
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        lm_attr = getattr(self.lm, attr)
        if not callable(lm_attr):
            return lm_attr

        def fn(requests):
            res = []
            remaining_reqs = []

            # figure out which ones are cached and which ones are new
            for req in requests:
                hsh = hash_args(attr, req)
                if hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)

            # actually run the LM on the requests that do not have cached results
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        return CacheHook(self)
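    # Hedged usage sketch (not part of the original file): wrap any LM so that
    # repeated requests are answered from the SQLite cache. `my_lm` is a
    # hypothetical LM instance and the path is illustrative.
    #
    #     cached_lm = CachingLM(my_lm, "lm_cache/my_model.db")
    #     results = cached_lm.loglikelihood([("The capital of France is", " Paris")])
    #     # an identical second call is served from the cache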


REQUEST_RETURN_LENGTHS = {
    "loglikelihood": 2,
    "greedy_until": None,
    "loglikelihood_rolling": None,
}


class Request:
    def __init__(self, request_type, args, index=None):
        if request_type not in REQUEST_RETURN_LENGTHS.keys():
            raise NotImplementedError(
                "The request type {} is not implemented!".format(request_type)
            )

        self.request_type = request_type
        self.args = args
        self.index = index

    def __iter__(self):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
            yield Request(self.request_type, self.args, i)

    def __getitem__(self, i):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        return Request(self.request_type, self.args, i)

    def __eq__(self, other):
        return (
            self.request_type == other.request_type
            and self.args == other.args
            and self.index == other.index
        )

    def __repr__(self):
        return f"Req_{self.request_type}{self.args}[{self.index}]\n"


class RequestFactory:
    def __getattr__(self, attr):
        def fn(*args):
            return Request(attr, args)

        return fn


rf = RequestFactory()
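# Hedged usage sketch (not part of the original file): tasks build requests
# through the module-level factory `rf`; the attribute name becomes the request
# type, and indexing selects one of the returned values (loglikelihood returns
# two: the log probability and the is-greedy flag).
#
#     ll_request = rf.loglikelihood("The capital of France is", " Paris")[0]
#     ll_request.request_type   # "loglikelihood"
#     ll_request.index          # 0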