import abc
from typing import Iterable
import numpy as np
import random
import re
import os
import json
import hashlib
import datasets
from sqlitedict import SqliteDict
from tqdm import tqdm
import torch
import torch.nn.functional as F
from accelerate import find_executable_batch_size

from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
from lm_eval import utils
from abc import abstractmethod


class LM(abc.ABC):
    def __init__(self):
        self.cache_hook = CacheHook(None)

    @abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass
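
    # Illustrative usage sketch (not part of the API; the numbers below are
    # hypothetical):
    #
    #     requests = [("The capital of France is", " Paris")]
    #     results = lm.loglikelihood(requests)
    #     # results -> [(-0.7, True)]  # (logprob of " Paris", is it the greedy continuation?)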

    @abstractmethod
    def loglikelihood_rolling(self, requests):
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list
            A list of strings
            string: str
                String for which we are computing per-token loglikelihood
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass
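
    # Illustrative usage sketch (hypothetical value): perplexity tasks pass whole
    # documents and get back one summed log-likelihood per document:
    #
    #     results = lm.loglikelihood_rolling(["a very long document ..."])
    #     # results -> [-1234.5]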

    # TODO: Add an optional max length
    @abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of strings (continuation)
            continuation: str
                The generated continuation.
        """
        pass
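
    # Illustrative usage sketch (hypothetical strings): generation stops at the first
    # stopping sequence, which is stripped from the returned continuation:
    #
    #     requests = [("Q: What is the capital of France?\nA:", ["\n"])]
    #     results = lm.greedy_until(requests)
    #     # results -> [" Paris"]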

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config=None):
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)
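
    # Example (illustrative; `SomeLM` is a hypothetical subclass, and it assumes
    # `utils.simple_parse_args_string` parses "key=value,key2=value2" into a dict
    # of keyword arguments):
    #
    #     SomeLM.create_from_arg_string("device=cuda:0", {"batch_size": 4})
    #     # roughly equivalent to SomeLM(device="cuda:0", batch_size=4)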

    def set_cache_hook(self, cache_hook):
        self.cache_hook = cache_hook


class BaseLM(LM):
    @property
    @abstractmethod
    def eot_token_id(self):
        pass

    @property
    @abstractmethod
    def max_length(self):
        pass

    @property
    @abstractmethod
    def max_gen_toks(self):
        pass

    @property
    @abstractmethod
    def batch_size(self):
        pass

    @property
    @abstractmethod
    def device(self):
        pass

    @abstractmethod
    def tok_encode(self, string: str):
        pass

    @abstractmethod
    def tok_decode(self, tokens: Iterable[int]):
        pass

    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id):
        pass

    @abstractmethod
    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        pass
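
    # A minimal sketch of an override (assumes a hypothetical `self.model` attribute
    # holding a HuggingFace causal LM):
    #
    #     def _model_call(self, inps):
    #         with torch.no_grad():
    #             return self.model(inps).logits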

    # subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
    # TODO: enforce this somehow

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context)

            continuation_enc = self.tok_encode(continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        # TODO: Implement caching once we've confirmed the perplexity implementation

        # automatic batch size detection for vectorization
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")

            @find_executable_batch_size(
                starting_batch_size=512
            )  # if OOM, then halves batch_size and tries again
            def forward_batch(batch_size):
                test_batch = torch.ones(
                    (batch_size, self.max_length), device=self.device
                ).long()
                for _ in range(5):
                    _ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
                return batch_size

            batch_size = forward_batch()
            print(f"Determined largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        loglikelihoods = []
        for (string,) in tqdm(requests):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            rolling_token_windows = [(None,) + x for x in rolling_token_windows]
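            # prepending (None,) gives each window an empty cache_key, so
            # _loglikelihood_tokens skips per-request partial caching for these windows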

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows,
                disable_tqdm=True,
                override_bs=adaptive_batch_size,
            )

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False, override_bs=None):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
        if len(re_ord.get_reordered()) > 0:
            _, context_enc, continuation_enc = re_ord.get_reordered()[0]
            max_context = len(
                (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
            )
            if self.batch_size == "auto":
                if override_bs is None:
                    print("Passed argument batch_size = auto. Detecting largest batch size")

                    @find_executable_batch_size(
                        starting_batch_size=512
                    )  # if OOM, then halves batch_size and tries again
                    def forward_batch(batch_size):
                        test_batch = torch.ones(
                            (batch_size, max_context), device=self.device
                        ).long()
                        for _ in range(5):
                            _ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
                        return batch_size

                    batch_size = forward_batch()
                    print(f"Determined largest batch size: {batch_size}")
                    adaptive_batch_size = batch_size
                else:
                    adaptive_batch_size = override_bs
        else:
            adaptive_batch_size = 0 if override_bs is None else override_bs

        for chunk in utils.chunks(
            tqdm(re_ord.get_reordered(), disable=disable_tqdm),
            self.batch_size if self.batch_size != "auto" else adaptive_batch_size,
        ):
            inps = []
            cont_toks_list = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # gpt2    \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                    dtype=torch.long,
                ).to(self.device)
                (inplen,) = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = (
                    padding_length if padding_length is not None else inplen
                )

                # pad length from seq to padding_length
                inp = torch.cat(
                    [
                        inp,  # [seq]
                        torch.zeros(padding_length - inplen, dtype=torch.long).to(
                            inp.device
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                )

                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(cont)
                inplens.append(inplen)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(
                self._model_call(batched_inps), dim=-1
            ).cpu()  # [batch, padding_length, vocab]

            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
                chunk, multi_logits, inps, inplens, cont_toks_list
            ):

                # Slice to original seq length
                contlen = len(cont_toks)
                logits = logits[inplen - contlen : inplen].unsqueeze(
                    0
                )  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
                    0
                )  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                    -1
                )  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return re_ord.get_original(res)

    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles until that are
        #       multiple tokens or that span multiple tokens correctly

        # TODO: extract to TokenizedLM?
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        for context, request_args in tqdm(re_ord.get_reordered()):
            until = request_args["until"]
            if isinstance(until, str):
                until = [until]

            if until:
                (primary_until,) = self.tok_encode(until[0])
            else:
                primary_until = None

            context_enc = torch.tensor(
                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
            ).to(self.device)

            max_gen_tokens = min(
                self.max_gen_toks, request_args.get("max_length", self.max_gen_toks)
            )
            cont = self._model_generate(
                context_enc, context_enc.shape[1] + max_gen_tokens, primary_until
            )

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])

            for term in until:
                s = s.split(term)[0]

            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)

            res.append(s)

        return re_ord.get_original(res)


class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation.

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary, e.g.
        {"question": ..., "answer": ...}
    or a tuple, e.g. (question, answer).
    """

    # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
    # or a path to a custom `datasets` loading script.
    DATASET_PATH: str = None

    # The name of a subset within `DATASET_PATH`.
    DATASET_NAME: str = None

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None):
        """
        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.download(data_dir, cache_dir, download_mode)
        self._training_docs = None
        self._fewshot_docs = None

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Downloads and returns the task dataset.
        Override this method to download the dataset from a custom API.

        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
        )
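
    # A minimal sketch of a concrete subclass configuration (illustrative; the path
    # and name below are just an example of HuggingFace Hub identifiers):
    #
    #     class BoolQ(Task):
    #         DATASET_PATH = "super_glue"
    #         DATASET_NAME = "boolq"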

    def should_decontaminate(self):
        """Whether this task supports decontamination against the model's training set."""
        return False

    @abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def test_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def _process_doc(self, doc):
        """
        Override this to process (detokenize, strip, replace, etc.) individual
        documents. This can be used in a map over documents of a data split.
        E.g. `map(self._process_doc, self.dataset["validation"])`

        :return: dict
            The processed version of the specified `doc`.
        """
        return doc

    def fewshot_examples(self, k, rnd):
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    def doc_to_decontamination_query(self, doc):
        print(
            "Override doc_to_decontamination_query with document specific decontamination query."
        )
        assert False

    @abstractmethod
    def doc_to_text(self, doc):
        pass

    @abstractmethod
    def doc_to_target(self, doc):
        pass

    @abstractmethod
    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        pass

    @abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluate, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass

    @abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        pass

    def fewshot_description(self):
        import warnings

        warnings.warn(
            "`fewshot_description` will be removed in future versions. Pass "
            "any custom descriptions to the `evaluate` function instead.",
            DeprecationWarning,
        )
        return ""

    @utils.positional_deprecated
    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param provide_description: bool
            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
        :param rnd: random.Random
            The pseudo-random number generator used to randomly sample examples.
            WARNING: This is currently a required arg although it has a default of `None` that makes it appear optional.
        :param description: str
            The task's description that will be prepended to the fewshot examples.
        :returns: str
            The fewshot context.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        description = description + "\n\n" if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs()
                        if self.has_validation_docs()
                        else self.test_docs()
                    )

                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = (
                "\n\n".join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + "\n\n"
            )

        example = self.doc_to_text(doc)
        return description + labeled_examples + example
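
    # Illustrative result (hypothetical docs): with description="Answer the question."
    # and num_fewshot=1, the returned string might look like
    #
    #     Answer the question.
    #
    #     Q: What color is the sky?
    #     A: blue
    #
    #     Q: What color is grass?
    #     A: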


class MultipleChoiceTask(Task):
    def doc_to_target(self, doc):
        return " " + doc["choices"][doc["gold"]]

    def construct_requests(self, doc, ctx):
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
        ]

        return lls

    def process_results(self, doc, results):
        gold = doc["gold"]

        acc = 1.0 if np.argmax(results) == gold else 0.0
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
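        # length-normalized accuracy compares per-character average log-likelihoods,
        # e.g. (illustrative) results = [-12.0, -9.0] with lengths [6, 3] normalizes
        # to [-2.0, -3.0], so the longer choice wins even though its raw sum is lower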
        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }

    def higher_is_better(self):
        return {
            "acc": True,
            "acc_norm": True,
        }

    def aggregation(self):
        return {
            "acc": mean,
            "acc_norm": mean,
        }


class PerplexityTask(Task, abc.ABC):
    def should_decontaminate(self):
        """Whether this task supports decontamination against the model's training set."""
        return True

    def has_training_docs(self):
        return False

    def fewshot_examples(self, k, rnd):
        assert k == 0
        return []

    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        assert (
            num_fewshot == 0
        ), "The number of fewshot examples must be 0 for perplexity tasks."
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`."
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        return ""

    def higher_is_better(self):
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_decontamination_query(self, doc):
        return doc

    def doc_to_text(self, doc):
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc, ctx):
        assert not ctx
        req = rf.loglikelihood_rolling(self.doc_to_target(doc))
        return req

    def process_results(self, doc, results):
        (loglikelihood,) = results
        words = self.count_words(doc)
        bytes_ = self.count_bytes(doc)
        return {
            "word_perplexity": (loglikelihood, words),
            "byte_perplexity": (loglikelihood, bytes_),
            "bits_per_byte": (loglikelihood, bytes_),
        }

    def aggregation(self):
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": bits_per_byte,
        }
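
    # The pairs returned by process_results are (loglikelihood, weight) tuples;
    # the aggregation functions above (see lm_eval.metrics) presumably reduce them
    # as exp(-sum(loglikelihood) / sum(weight)) for the perplexities and divide the
    # per-byte negative log-likelihood by ln(2) for bits_per_byte.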

    @classmethod
    def count_bytes(cls, doc):
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc):
        """Downstream tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))


def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()


class CacheHook:
    def __init__(self, cachinglm):
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res):
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db):
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        def fn(requests):
            res = []
            remaining_reqs = []

            # figure out which ones are cached and which ones are new
            for req in requests:
                hsh = hash_args(attr, req)
                if hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)

            # actually run the LM on the requests that do not have cached results
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        return CacheHook(self)
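

# Illustrative usage sketch (hypothetical path and LM object):
#
#     lm = CachingLM(underlying_lm, "lm_cache/gpt2.db")
#     results = lm.loglikelihood([("The capital of France is", " Paris")])
#     # a second identical call would be served from the SQLite cache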


REQUEST_RETURN_LENGTHS = {
    "loglikelihood": 2,
    "greedy_until": None,
    "loglikelihood_rolling": None,
}


class Request:
    def __init__(self, request_type, args, index=None):
        if request_type not in REQUEST_RETURN_LENGTHS.keys():
            raise NotImplementedError(
                "The request type {} is not implemented!".format(request_type)
            )

        self.request_type = request_type
        self.args = args
        self.index = index

    def __iter__(self):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
            yield Request(self.request_type, self.args, i)

    def __getitem__(self, i):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        return Request(self.request_type, self.args, i)

    def __eq__(self, other):
        return (
            self.request_type == other.request_type
            and self.args == other.args
            and self.index == other.index
        )

    def __repr__(self):
        return f"Req_{self.request_type}{self.args}[{self.index}]\n"


class RequestFactory:
    def __getattr__(self, attr):
        def fn(*args):
            return Request(attr, args)

        return fn


rf = RequestFactory()
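
# Example (illustrative): `rf.loglikelihood(ctx, cont)` builds
# Request("loglikelihood", (ctx, cont)); because REQUEST_RETURN_LENGTHS["loglikelihood"]
# is 2, indexing the request with [0] or [1] selects the logprob or is-greedy slot of
# the eventual result, as done in MultipleChoiceTask.construct_requests.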