import abc
from typing import Iterable
import numpy as np
import random
import re
import os
import json
import hashlib
import datasets
from sqlitedict import SqliteDict
from tqdm import tqdm
import torch
import torch.nn.functional as F
from accelerate import find_executable_batch_size
import gc

from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
from lm_eval import utils
from abc import abstractmethod


class LM(abc.ABC):
    def __init__(self):
        self.cache_hook = CacheHook(None)

    @abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass
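    # Illustrative sketch (`MyLM` is a hypothetical concrete subclass):
    #
    #   lm = MyLM()
    #   [(logprob, isgreedy)] = lm.loglikelihood([("hello", " world")])
    #
    # Note the leading space in " world": word-boundary whitespace belongs to
    # the continuation, per the convention documented above.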

    @abstractmethod
    def loglikelihood_rolling(self, requests):
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list
            A list of strings
            string: str
                String for which we are computing per-token loglikelihood
        :return: list
            A list of floats
            logprob: float
                The summed log probability of `string`'s tokens, computed over
                the disjoint rolling windows described above
        """
        pass

    # TODO: Add an optional max length
    @abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of continuation strings
            continuation: str
                The generated continuation.
        """
        pass
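    # Illustrative sketch (`MyLM` is a hypothetical concrete subclass):
    #
    #   lm.greedy_until([("Question: 2 + 2 = ?\nAnswer:", ["\n"])])
    #
    # might return [" 4"], i.e. the text generated up to (but excluding) the
    # first stop sequence "\n".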

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config=None):
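        """Construct an instance from a comma-separated ``key=value`` string.

        Illustrative sketch (the `GPT2LM` class name is hypothetical):
            GPT2LM.create_from_arg_string("pretrained=gpt2,device=cuda")
        is roughly equivalent to GPT2LM(pretrained="gpt2", device="cuda"),
        with the non-None entries of `additional_config` merged in as extra
        keyword arguments.
        """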
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)

    def set_cache_hook(self, cache_hook):
        self.cache_hook = cache_hook


class BaseLM(LM):
    @property
    @abstractmethod
    def eot_token_id(self):
        pass

    @property
    @abstractmethod
    def max_length(self):
        pass

    @property
    @abstractmethod
    def max_gen_toks(self):
        pass

    @property
    @abstractmethod
    def batch_size(self):
        pass

    @property
    @abstractmethod
    def device(self):
        pass

    @abstractmethod
    def tok_encode(self, string: str):
        pass

    @abstractmethod
    def tok_decode(self, tokens: Iterable[int]):
        pass

    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id):
        pass

    @abstractmethod
    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        pass

    # subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
    # TODO: enforce this somehow

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context)

            continuation_enc = self.tok_encode(continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        # TODO: Implement caching once we've confirmed the perplexity implementation
        # automatic batch size detection for vectorization
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")

            @find_executable_batch_size(starting_batch_size=512)  # if OOM, halves batch size and retries
            def forward_batch(batch_size):
                test_batch = torch.ones((batch_size, self.max_length), device=self.device).long()
                for _ in range(5):
                    _ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
                return batch_size

            batch_size = forward_batch()
            print(f"Determined largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        loglikelihoods = []
        for (string,) in tqdm(requests):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

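            # prepend a dummy cache key (None) to each (context, continuation)
            # window; _loglikelihood_tokens skips partial caching when the key is None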
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows, disable_tqdm=True, override_bs=adaptive_batch_size
            )

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False, override_bs=None):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
        _, context_enc, continuation_enc = re_ord.get_reordered()[0]
        max_context = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
        if self.batch_size == "auto":
            if override_bs is None:
                print("Passed argument batch_size = auto. Detecting largest batch size")

                @find_executable_batch_size(starting_batch_size=512)  # if OOM, halves batch size and retries
                def forward_batch(batch_size):
                    test_batch = torch.ones((batch_size, max_context), device=self.device).long()
                    for _ in range(5):
                        _ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
                    return batch_size

                batch_size = forward_batch()
                print(f"Determined largest batch size: {batch_size}")
                adaptive_batch_size = batch_size
            else:
                adaptive_batch_size = override_bs

        for chunk in utils.chunks(
            tqdm(re_ord.get_reordered(), disable=disable_tqdm),
            self.batch_size if self.batch_size != "auto" else adaptive_batch_size,
        ):
            inps = []
            cont_toks_list = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # gpt2    \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                    dtype=torch.long,
                ).to(self.device)
                (inplen,) = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = (
                    padding_length if padding_length is not None else inplen
                )

                # pad length from seq to padding_length
                inp = torch.cat(
                    [
                        inp,  # [seq]
                        torch.zeros(padding_length - inplen, dtype=torch.long).to(
                            inp.device
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                )

                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(cont)
                inplens.append(inplen)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(
                self._model_call(batched_inps), dim=-1
            ).cpu()  # [batch, padding_length, vocab]

            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
                chunk, multi_logits, inps, inplens, cont_toks_list
            ):

                # Slice to original seq length
                contlen = len(cont_toks)
                logits = logits[inplen - contlen : inplen].unsqueeze(
                    0
                )  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
                    0
                )  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                    -1
                )  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return re_ord.get_original(res)

    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles until that are
        #       multiple tokens or that span multiple tokens correctly

        # TODO: extract to TokenizedLM?
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(re_ord.get_reordered()):
            if isinstance(until, str):
                until = [until]

            (primary_until,) = self.tok_encode(until[0])

            context_enc = torch.tensor(
                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
            ).to(self.device)

            cont = self._model_generate(
                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
            )

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])

            for term in until:
                s = s.split(term)[0]

            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)

            res.append(s)

        return re_ord.get_original(res)


class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation.

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary, e.g.
        {"question": ..., "answer": ...} or
        a tuple such as (question, answer)
    """

    # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
    # or a path to a custom `datasets` loading script.
    DATASET_PATH: str = None

    # The name of a subset within `DATASET_PATH`.
    DATASET_NAME: str = None

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None):
        """
        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.download(data_dir, cache_dir, download_mode)
        self._training_docs = None
        self._fewshot_docs = None

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Downloads and returns the task dataset.
        Override this method to download the dataset from a custom API.

        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
        )

    def should_decontaminate(self):
        """Whether this task supports decontamination against model training set."""
        return False

    @abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def test_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def _process_doc(self, doc):
        """
        Override this to process (detokenize, strip, replace, etc.) individual
        documents. This can be used in a map over documents of a data split.
        E.g. `map(self._process_doc, self.dataset["validation"])`

        :return: dict
            The processed version of the specified `doc`.
        """
        return doc

    def fewshot_examples(self, k, rnd):
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    def doc_to_decontamination_query(self, doc):
        raise NotImplementedError(
            "Override doc_to_decontamination_query with a document-specific decontamination query."
        )

    @abstractmethod
    def doc_to_text(self, doc):
        pass

    @abstractmethod
    def doc_to_target(self, doc):
        pass

    @abstractmethod
    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few-shot examples, and the question
            part of the document for `doc`.
        """
        pass

    @abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results, and evaluate, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass

    @abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        pass

    def fewshot_description(self):
        import warnings

        warnings.warn(
            "`fewshot_description` will be removed in future versions. Pass "
            "any custom descriptions to the `evaluate` function instead.",
            DeprecationWarning,
        )
        return ""

    @utils.positional_deprecated
    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param provide_description: bool
            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
        :param rnd: random.Random
            The pseudo-random number generator used to randomly sample examples.
            WARNING: This is currently a required arg although it's optionalized with a default `None`.
        :param description: str
            The task's description that will be prepended to the fewshot examples.
        :returns: str
            The fewshot context.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        description = description + "\n\n" if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs()
                        if self.has_validation_docs()
                        else self.test_docs()
                    )

                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = (
                "\n\n".join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + "\n\n"
            )

        example = self.doc_to_text(doc)
        return description + labeled_examples + example
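
    # Illustrative result (hypothetical arithmetic task, num_fewshot=1,
    # description="Answer the question."):
    #
    #   Answer the question.
    #
    #   Q: 2 + 2 = ?
    #   A: 4
    #
    #   Q: 3 + 5 = ?
    #   A: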


class MultipleChoiceTask(Task):
    def doc_to_target(self, doc):
        return " " + doc["choices"][doc["gold"]]

    def construct_requests(self, doc, ctx):
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
        ]

        return lls

    def process_results(self, doc, results):
        gold = doc["gold"]

        acc = 1.0 if np.argmax(results) == gold else 0.0
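        # acc_norm divides each loglikelihood by its choice's character length,
        # reducing the bias toward shorter completions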
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }

    def higher_is_better(self):
        return {
            "acc": True,
            "acc_norm": True,
        }

    def aggregation(self):
        return {
            "acc": mean,
            "acc_norm": mean,
        }


class PerplexityTask(Task, abc.ABC):
    def should_decontaminate(self):
        """Whether this task supports decontamination against model training set."""
        return True

    def has_training_docs(self):
        return False

    def fewshot_examples(self, k, rnd):
        assert k == 0
        return []

    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        assert (
            num_fewshot == 0
        ), "The number of fewshot examples must be 0 for perplexity tasks."
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`."
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        return ""

    def higher_is_better(self):
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_decontamination_query(self, doc):
        return doc

    def doc_to_text(self, doc):
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc, ctx):
        assert not ctx
        req = rf.loglikelihood_rolling(self.doc_to_target(doc))
        return req

    def process_results(self, doc, results):
        (loglikelihood,) = results
        words = self.count_words(doc)
        bytes_ = self.count_bytes(doc)
        return {
            "word_perplexity": (loglikelihood, words),
            "byte_perplexity": (loglikelihood, bytes_),
            "bits_per_byte": (loglikelihood, bytes_),
        }
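
    # Each metric value above is a (loglikelihood, weight) pair; the functions
    # returned by aggregation() reduce these pairs across documents (e.g.
    # weighted_perplexity exponentiates the negative weighted-average log-likelihood).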

    def aggregation(self):
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": bits_per_byte,
        }

    @classmethod
    def count_bytes(cls, doc):
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc):
        """Downstream tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))


def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()
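
# For example, hash_args("loglikelihood", ("hello", " world")) yields a stable
# sha256 hex digest, used below as the key into the sqlite-backed cache.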


class CacheHook:
    def __init__(self, cachinglm):
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res):
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db):
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        def fn(requests):
            res = []
            remaining_reqs = []

            # figure out which ones are cached and which ones are new
            for req in requests:
                hsh = hash_args(attr, req)
                if hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)

            # actually run the LM on the requests that do not have cached results
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        return CacheHook(self)


REQUEST_RETURN_LENGTHS = {
    "loglikelihood": 2,
    "greedy_until": None,
    "loglikelihood_rolling": None,
}


class Request:
    def __init__(self, request_type, args, index=None):
        if request_type not in REQUEST_RETURN_LENGTHS:
            raise NotImplementedError(
                "The request type {} is not implemented!".format(request_type)
            )

        self.request_type = request_type
        self.args = args
        self.index = index

    def __iter__(self):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
            yield Request(self.request_type, self.args, i)

    def __getitem__(self, i):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        return Request(self.request_type, self.args, i)

    def __eq__(self, other):
        return (
            self.request_type == other.request_type
            and self.args == other.args
            and self.index == other.index
        )

    def __repr__(self):
        return f"Req_{self.request_type}{self.args}[{self.index}]\n"


class RequestFactory:
    def __getattr__(self, attr):
        def fn(*args):
            return Request(attr, args)

        return fn


rf = RequestFactory()
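
# Illustrative: rf.loglikelihood(ctx, cont) builds Request("loglikelihood", (ctx, cont));
# since loglikelihood returns two values, a task can index the request, e.g.
# rf.loglikelihood(ctx, cont)[0] selects the logprob slot (see MultipleChoiceTask above).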