import abc
from typing import Iterable

import numpy as np
import random
import re

import os
import json
import hashlib

import datasets
from sqlitedict import SqliteDict
from tqdm import tqdm

import torch
import torch.nn.functional as F

from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
from lm_eval import utils
from abc import abstractmethod


class LM(abc.ABC):
    def __init__(self):
        self.cache_hook = CacheHook(None)

    @abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass
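
    # Illustrative call, with made-up numbers (the exact logprob depends on the model):
    #   lm.loglikelihood([("The capital of France is", " Paris")])
    #   -> [(-1.9, True)]   # logprob of " Paris", and whether " Paris" is the greedy continuation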

    @abstractmethod
    def loglikelihood_rolling(self, requests):
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list
            A list of strings
            string: str
                String for which we are computing per-token loglikelihood
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass

    # TODO: Add an optional max length
    @abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of strings continuation
            continuation: str
                The generated continuation.
        """
        pass

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config=None):
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)
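
    # Usage sketch (hypothetical arguments): a concrete subclass could be built with
    #   SomeLM.create_from_arg_string("device=cuda,batch_size=4", {"max_length": 2048})
    # where the comma-separated key=value pairs are parsed by utils.simple_parse_args_string
    # and merged with the non-None entries of additional_config before calling the constructor.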

    def set_cache_hook(self, cache_hook):
        self.cache_hook = cache_hook


class BaseLM(LM):
    @property
    @abstractmethod
    def eot_token_id(self):
        pass

    @property
    @abstractmethod
    def max_length(self):
        pass

    @property
    @abstractmethod
    def max_gen_toks(self):
        pass

    @property
    @abstractmethod
    def batch_size(self):
        pass

    @property
    @abstractmethod
    def device(self):
        pass

    @abstractmethod
    def tok_encode(self, string: str):
        pass

    @abstractmethod
    def tok_decode(self, tokens: Iterable[int]):
        pass

    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id, temperature=0.):
        pass

    @abstractmethod
    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        pass
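
    # A concrete subclass typically implements _model_call as a plain forward pass.
    # Sketch (an assumption, for a HuggingFace-style causal LM stored on self.model):
    #     with torch.no_grad():
    #         return self.model(inps)[0]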

    # subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
    # TODO: enforce this somehow

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context)

            continuation_enc = self.tok_encode(continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        # TODO: Implement caching once we've confirmed the perplexity implementation
        # TODO: automatic batch size detection for vectorization

        loglikelihoods = []
        for (string,) in tqdm(requests):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows, disable_tqdm=True
            )

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        # TODO: automatic (variable) batch size detection for vectorization
        re_ord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
            tqdm(re_ord.get_reordered(), disable=disable_tqdm), self.batch_size
        ):
            inps = []
            cont_toks_list = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # gpt2    \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                    dtype=torch.long,
                ).to(self.device)
                (inplen,) = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = (
                    padding_length if padding_length is not None else inplen
                )

                # pad length from seq to padding_length
                inp = torch.cat(
                    [
                        inp,  # [seq]
                        torch.zeros(padding_length - inplen, dtype=torch.long).to(
                            inp.device
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                )

                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(cont)
                inplens.append(inplen)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(
                self._model_call(batched_inps), dim=-1
            ).cpu()  # [batch, padding_length, vocab]

            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
                chunk, multi_logits, inps, inplens, cont_toks_list
            ):

                # Slice to original seq length
                contlen = len(cont_toks)
                logits = logits[inplen - contlen : inplen].unsqueeze(
                    0
                )  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
                    0
                )  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                    -1
                )  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return re_ord.get_original(res)

    def multiple_temperature_sample_until(self, requests, k=32, temperature=0.3):
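        """Sample k continuations per context at a fixed temperature.

        `requests` is assumed (by analogy with `greedy_until`) to be a list of
        (context, until) pairs. For each pair, k candidates are sampled at the
        given temperature and truncated at the first stop sequence, so the
        return value is a list (one entry per request) of lists of k strings.
        """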
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(re_ord.get_reordered()):
            if isinstance(until, str):
                until = [until]

            (primary_until,) = self.tok_encode(until[0])

            context_enc = torch.tensor(
                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
            ).to(self.device)
            assert context_enc.shape[0] == 1

            context_enc = context_enc.expand(k, context_enc.shape[1])
            cont = self._model_generate(
                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until,
                temperature=temperature
            )
            
            generated_tokens = cont[:, context_enc.shape[1]:]
            s = [self.tok_decode(candidate) for candidate in generated_tokens]
            for term in until:
                s = [candidate.split(term)[0] for candidate in s]

            # partial caching
            self.cache_hook.add_partial("multiple_temperature_sample_until", (context, until, k, temperature), s)

            res.append(s)
        return re_ord.get_original(res)


    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles until that are
        #       multiple tokens or that span multiple tokens correctly

        # TODO: extract to TokenizedLM?
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(re_ord.get_reordered()):
            if isinstance(until, str):
                until = [until]

            (primary_until,) = self.tok_encode(until[0])

            context_enc = torch.tensor(
                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
            ).to(self.device)

            cont = self._model_generate(
                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
            )

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])

            for term in until:
                s = s.split(term)[0]

            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)

            res.append(s)

        return re_ord.get_original(res)


class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary e.g.
        {"question": ..., "answer": ...} or
        {"question": ..., question, answer)
    """

    # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
    # or a path to a custom `datasets` loading script.
    DATASET_PATH: str = None

    # The name of a subset within `DATASET_PATH`.
    DATASET_NAME: str = None

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None):
        """
        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.download(data_dir, cache_dir, download_mode)
        self._training_docs = None
        self._fewshot_docs = None

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Downloads and returns the task dataset.
        Override this method to download the dataset from a custom API.

        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
        )

    def should_decontaminate(self):
        """Whether this task supports decontamination against model training set."""
        return False

    @abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def test_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def _process_doc(self, doc):
        """
        Override this to process (detokenize, strip, replace, etc.) individual
        documents. This can be used in a map over documents of a data split.
        E.g. `map(self._process_doc, self.dataset["validation"])`

        :return: dict
            The processed version of the specified `doc`.
        """
        return doc

    def fewshot_examples(self, k, rnd):
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    def doc_to_decontamination_query(self, doc):
        print(
            "Override doc_to_decontamination_query with document specific decontamination query."
        )
        assert False

    @abstractmethod
    def doc_to_text(self, doc):
        pass

    @abstractmethod
    def doc_to_target(self, doc):
        pass

    @abstractmethod
    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        pass

    @abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluate them, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass

    @abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        pass

    def fewshot_description(self):
        import warnings

        warnings.warn(
            "`fewshot_description` will be removed in future versions. Pass "
            "any custom descriptions to the `evaluate` function instead.",
            DeprecationWarning,
        )
        return ""

    @utils.positional_deprecated
    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param provide_description: bool
            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
        :param rnd: random.Random
            The pseudo-random number generator used to randomly sample examples.
            WARNING: This is currently a required arg although it's optionalized with a default `None`.
        :param description: str
            The task's description that will be prepended to the fewshot examples.
        :returns: str
            The fewshot context.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        description = description + "\n\n" if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs()
                        if self.has_validation_docs()
                        else self.test_docs()
                    )

                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = (
                "\n\n".join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + "\n\n"
            )

        example = self.doc_to_text(doc)
        return description + labeled_examples + example
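
    # The assembled prompt therefore looks like (sketch):
    #   "<description>\n\n<example 1 text+target>\n\n...<example N text+target>\n\n<doc text>"
    # and the model is expected to continue from the final, unanswered example.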


class MultipleChoiceTask(Task):
    def doc_to_target(self, doc):
        return " " + doc["choices"][doc["gold"]]

    def construct_requests(self, doc, ctx):
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
        ]

        return lls

    def process_results(self, doc, results):
        gold = doc["gold"]

        acc = 1.0 if np.argmax(results) == gold else 0.0
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }
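
    # Note: acc_norm length-normalizes each choice's loglikelihood by the character
    # length of the choice string before taking the argmax, so longer answer options
    # are not penalized simply for containing more tokens.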

    def higher_is_better(self):
        return {
            "acc": True,
            "acc_norm": True,
        }

    def aggregation(self):
        return {
            "acc": mean,
            "acc_norm": mean,
        }


class PerplexityTask(Task, abc.ABC):
    def should_decontaminate(self):
        """Whether this task supports decontamination against model training set."""
        return True

    def has_training_docs(self):
        return False

    def fewshot_examples(self, k, rnd):
        assert k == 0
        return []

    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        assert (
            num_fewshot == 0
        ), "The number of fewshot examples must be 0 for perplexity tasks."
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`."
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        return ""

    def higher_is_better(self):
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_decontamination_query(self, doc):
        return doc

    def doc_to_text(self, doc):
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc, ctx):
        assert not ctx
        req = rf.loglikelihood_rolling(self.doc_to_target(doc))
        return req

    def process_results(self, doc, results):
        (loglikelihood,) = results
        words = self.count_words(doc)
        bytes_ = self.count_bytes(doc)
        return {
            "word_perplexity": (loglikelihood, words),
            "byte_perplexity": (loglikelihood, bytes_),
            "bits_per_byte": (loglikelihood, bytes_),
        }
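
    # Each metric is emitted as a (total_loglikelihood, unit_count) pair; the
    # aggregation functions below combine these pairs into corpus-level
    # weighted perplexities and bits-per-byte.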

    def aggregation(self):
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": bits_per_byte,
        }

    @classmethod
    def count_bytes(cls, doc):
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc):
        """Downstream tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))


def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()


class CacheHook:
    def __init__(self, cachinglm):
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res):
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db):
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())
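
    # Usage sketch (hypothetical path):
    #   lm = CachingLM(base_lm, "lm_cache/model.db")
    #   lm.loglikelihood(requests)  # repeated requests are served from the SQLite cache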

    def __getattr__(self, attr):
        def fn(requests):
            res = []
            remaining_reqs = []

            # figure out which ones are cached and which ones are new
            for req in requests:
                hsh = hash_args(attr, req)
                if hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)

            # actually run the LM on the requests that do not have cached results
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        return CacheHook(self)


REQUEST_RETURN_LENGTHS = {
    "loglikelihood": 2,
    "greedy_until": None,
    "multiple_temperature_sample_until": None,
    "loglikelihood_rolling": None,
}


class Request:
    def __init__(self, request_type, args, index=None):
        if request_type not in REQUEST_RETURN_LENGTHS.keys():
            raise NotImplementedError(
                "The request type {} is not implemented!".format(request_type)
            )

        self.request_type = request_type
        self.args = args
        self.index = index

    def __iter__(self):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
            yield Request(self.request_type, self.args, i)

    def __getitem__(self, i):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        return Request(self.request_type, self.args, i)

    def __eq__(self, other):
        return (
            self.request_type == other.request_type
            and self.args == other.args
            and self.index == other.index
        )

    def __repr__(self):
        return f"Req_{self.request_type}{self.args}[{self.index}]\n"


class RequestFactory:
    def __getattr__(self, attr):
        def fn(*args):
            return Request(attr, args)

        return fn


rf = RequestFactory()
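
# Usage sketch: task code builds deferred requests through this module-level factory,
# e.g. `rf.loglikelihood(ctx, " answer")` yields Request("loglikelihood", (ctx, " answer")),
# and the evaluator later dispatches the collected requests to the LM in batches.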