import abc
from typing import Iterable

import numpy as np
import random
import re
import os
import json
import hashlib
import datasets
from sqlitedict import SqliteDict
from tqdm import tqdm
import torch
import torch.nn.functional as F
from accelerate import find_executable_batch_size

from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
from lm_eval import utils
from abc import abstractmethod


class LM(abc.ABC):
    def __init__(self):
        self.cache_hook = CacheHook(None)

    @abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass
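
    # Illustrative sketch of the request/return convention documented above
    # (`MyLM` is a hypothetical concrete subclass, not part of this module):
    #
    #     lm = MyLM()
    #     results = lm.loglikelihood([("The capital of France is", " Paris")])
    #     logprob, is_greedy = results[0]
    #     # `logprob` is log P(" Paris" | "The capital of France is");
    #     # `is_greedy` is True iff greedy decoding from the context yields " Paris".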

    @abstractmethod
    def loglikelihood_rolling(self, requests):
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list
            A list of strings
            string: str
                String for which we are computing per-token loglikelihood
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass

    # TODO: Add an optional max length
    @abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of continuation strings
            continuation: str
                The generated continuation.
        """
        pass
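
    # Illustrative request shape for the convention above (hedged sketch; `MyLM`
    # is a hypothetical subclass):
    #
    #     lm = MyLM()
    #     [completion] = lm.greedy_until([("Q: 2+2=\nA:", ["\n"])])
    #     # `completion` is the greedy continuation of the context, truncated at
    #     # the first occurrence of any string in `until` (here, a newline).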

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config=None):
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)
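
    # Illustrative sketch of how a CLI model-argument string reaches a constructor
    # (key names and model name below are hypothetical examples):
    #
    #     # "pretrained=gpt2,batch_size=8" is parsed by utils.simple_parse_args_string
    #     # into key/value pairs (roughly {"pretrained": "gpt2", "batch_size": "8"}),
    #     # merged with non-None entries of `additional_config`, and passed as kwargs:
    #     lm = SomeLMSubclass.create_from_arg_string("pretrained=gpt2,batch_size=8")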

    def set_cache_hook(self, cache_hook):
        self.cache_hook = cache_hook


class BaseLM(LM):
    @property
    @abstractmethod
    def eot_token_id(self):
        pass

    @property
    @abstractmethod
    def max_length(self):
        pass

    @property
    @abstractmethod
    def max_gen_toks(self):
        pass

    @property
    @abstractmethod
    def batch_size(self):
        pass

    @property
    @abstractmethod
    def device(self):
        pass

    @abstractmethod
    def tok_encode(self, string: str):
        pass

    @abstractmethod
    def tok_decode(self, tokens: Iterable[int]):
        pass

    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id):
        pass

    @abstractmethod
    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        pass

    # subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
    # TODO: enforce this somehow

    def _encode_pair(self, context, continuation):
        whole_enc = self.tok_encode(context + continuation)
        whole_enc_len = len(whole_enc)
        context_enc = self.tok_encode(context)
        context_enc_len = len(context_enc)
        if context_enc_len < whole_enc_len:
            continuation_enc = whole_enc[context_enc_len:]
        else:
            continuation_enc = self.tok_encode(continuation)
            continuation_enc_len = len(continuation_enc)
            context_enc = whole_enc[:-continuation_enc_len]
        return context_enc, continuation_enc
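
    # Why encode the pair jointly: with BPE-style tokenizers, encoding context and
    # continuation separately can tokenize the boundary differently from encoding
    # their concatenation. A hedged, tokenizer-dependent illustration (ids are
    # schematic, not from any specific vocabulary):
    #
    #     # tok_encode("Hello" + " world") -> [15496, 995]
    #     # tok_encode("Hello") + tok_encode(" world") may split the boundary
    #     # differently, so the continuation tokens are instead recovered as
    #     # whole_enc[len(context_enc):].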

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(continuation)
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        # TODO: Implement caching once we've confirmed the perplexity implementation

        # automatic batch size detection for vectorization
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")

            @find_executable_batch_size(
                starting_batch_size=512
            )  # if OOM, then halves batch_size and tries again
            def forward_batch(batch_size):
                test_batch = torch.ones(
                    (batch_size, self.max_length), device=self.device
                ).long()
                for _ in range(5):
                    _ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
                return batch_size

            batch_size = forward_batch()
            print(f"Determined largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        loglikelihoods = []
        for (string,) in tqdm(requests):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            rolling_token_windows = [(None,) + x for x in rolling_token_windows]
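
            # Illustrative shape of the windows built above, mirroring the example in
            # LM.loglikelihood_rolling's docstring (token ids are schematic): with
            # max_seq_len=4 and an EOT prefix, a 10-token document yields disjoint
            # (context, continuation) pairs roughly like
            #     ([EOT], [0, 1, 2, 3]), ([3], [4, 5, 6, 7]), ([5, 6, 7], [8, 9])
            # The prepended `None` is a placeholder cache key, so these windows are
            # scored by _loglikelihood_tokens without being cached per-request.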

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows,
                disable_tqdm=True,
                override_bs=adaptive_batch_size,
            )

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False, override_bs=None):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
        if len(re_ord.get_reordered()) > 0:
            _, context_enc, continuation_enc = re_ord.get_reordered()[0]
            max_context = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
            if self.batch_size == "auto":
                if override_bs is None:
                    print("Passed argument batch_size = auto. Detecting largest batch size")

                    @find_executable_batch_size(
                        starting_batch_size=512
                    )  # if OOM, then halves batch_size and tries again
                    def forward_batch(batch_size):
                        test_batch = torch.ones(
                            (batch_size, max_context), device=self.device
                        ).long()
                        for _ in range(5):
                            _ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
                        return batch_size

                    batch_size = forward_batch()
                    print(f"Determined largest batch size: {batch_size}")
                    adaptive_batch_size = batch_size

                else:
                    adaptive_batch_size = override_bs
        else:
            adaptive_batch_size = 0 if override_bs is None else override_bs

        for chunk in utils.chunks(
            tqdm(re_ord.get_reordered(), disable=disable_tqdm),
            self.batch_size if self.batch_size != "auto" else adaptive_batch_size,
        ):
            inps = []
            cont_toks_list = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # gpt2    \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                    dtype=torch.long,
                ).to(self.device)
                (inplen,) = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = (
                    padding_length if padding_length is not None else inplen
                )

                # pad length from seq to padding_length
                inp = torch.cat(
                    [
                        inp,  # [seq]
                        torch.zeros(padding_length - inplen, dtype=torch.long).to(
                            inp.device
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                )

                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(cont)
                inplens.append(inplen)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(
                self._model_call(batched_inps), dim=-1
            ).cpu()  # [batch, padding_length, vocab]

            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
                chunk, multi_logits, inps, inplens, cont_toks_list
            ):

                # Slice to original seq length
                contlen = len(cont_toks)
                logits = logits[inplen - contlen : inplen].unsqueeze(
                    0
                )  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
                    0
                )  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                    -1
                )  # [1, seq]

                # Answer: (log prob, is-exact-match)
382
383
384
385
386
387
388
389
                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return re_ord.get_original(res)

    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles until that are
        #       multiple tokens or that span multiple tokens correctly

        # TODO: extract to TokenizedLM?
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        for context, request_args in tqdm(re_ord.get_reordered()):
            until = request_args["until"]
            if isinstance(until, str):
                until = [until]

            if until:
                (primary_until,) = self.tok_encode(until[0])
            else:
                primary_until = None

            context_enc = torch.tensor(
                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
            ).to(self.device)

            max_gen_tokens = min(
                self.max_gen_toks, request_args.get("max_length", self.max_gen_toks)
            )
            cont = self._model_generate(
                context_enc, context_enc.shape[1] + max_gen_tokens, primary_until
            )

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])

            for term in until:
                s = s.split(term)[0]

            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)

            res.append(s)

        return re_ord.get_original(res)


class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary, e.g.
        {"question": ..., "answer": ...} or
        a tuple, e.g. (question, answer)
    """

    # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
    # or a path to a custom `datasets` loading script.
    DATASET_PATH: str = None

    # The name of a subset within `DATASET_PATH`.
    DATASET_NAME: str = None

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None):
        """
        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.download(data_dir, cache_dir, download_mode)
        self._training_docs = None
        self._fewshot_docs = None

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Downloads and returns the task dataset.
        Override this method to download the dataset from a custom API.

        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
        )

    def should_decontaminate(self):
        """Whether this task supports decontamination against model training set."""
        return False

    @abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def test_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def _process_doc(self, doc):
        """
        Override this to process (detokenize, strip, replace, etc.) individual
        documents. This can be used in a map over documents of a data split.
        E.g. `map(self._process_doc, self.dataset["validation"])`

        :return: dict
            The processed version of the specified `doc`.
        """
        return doc

    def fewshot_examples(self, k, rnd):
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    def doc_to_decontamination_query(self, doc):
        print(
            "Override doc_to_decontamination_query with document specific decontamination query."
        )
        assert False

    @abstractmethod
    def doc_to_text(self, doc):
        pass

    @abstractmethod
    def doc_to_target(self, doc):
        pass

    @abstractmethod
    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        pass

    @abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluate, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass

    @abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        pass

    def fewshot_description(self):
        import warnings

        warnings.warn(
            "`fewshot_description` will be removed in future versions. Pass "
            "any custom descriptions to the `evaluate` function instead.",
            DeprecationWarning,
        )
        return ""

    @utils.positional_deprecated
    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param provide_description: bool
            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
        :param rnd: random.Random
            The pseudo-random number generator used to randomly sample examples.
            WARNING: This is currently a required arg although it's optionalized with a default `None`.
        :param description: str
            The task's description that will be prepended to the fewshot examples.
        :returns: str
            The fewshot context.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        description = description + "\n\n" if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs()
                        if self.has_validation_docs()
                        else self.test_docs()
                    )

                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = (
                "\n\n".join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + "\n\n"
            )

        example = self.doc_to_text(doc)
        return description + labeled_examples + example
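
    # Illustrative sketch of the assembled context for num_fewshot=1 (contents
    # are schematic, not from a real task):
    #
    #     description + "\n\n"                              e.g. "Answer the question.\n\n"
    #     doc_to_text(ex) + doc_to_target(ex) + "\n\n"      e.g. "Q: 1+1?\nA: 2\n\n"
    #     doc_to_text(doc)                                  e.g. "Q: 2+2?\nA:"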


class MultipleChoiceTask(Task):
    def doc_to_target(self, doc):
        return " " + doc["choices"][doc["gold"]]

    def construct_requests(self, doc, ctx):
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
        ]

        return lls

    def process_results(self, doc, results):
        gold = doc["gold"]

        acc = 1.0 if np.argmax(results) == gold else 0.0
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }
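
    # Worked example of the two metrics above (numbers are illustrative): with
    # choices ["cat", "giraffe"], gold=1, and summed loglikelihoods
    # results = [-4.0, -6.0], argmax(results) = 0, so acc = 0.0; normalizing by
    # character length gives [-4.0/3, -6.0/7] = [-1.33, -0.86], argmax = 1,
    # so acc_norm = 1.0 (length-normalized scoring can pick a different answer).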

    def higher_is_better(self):
        return {
            "acc": True,
            "acc_norm": True,
        }

    def aggregation(self):
        return {
            "acc": mean,
            "acc_norm": mean,
        }


class PerplexityTask(Task, abc.ABC):
    def should_decontaminate(self):
        """Whether this task supports decontamination against model training set."""
        return True

    def has_training_docs(self):
        return False

    def fewshot_examples(self, k, rnd):
        assert k == 0
        return []

    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        assert (
            num_fewshot == 0
        ), "The number of fewshot examples must be 0 for perplexity tasks."
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`."
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        return ""

    def higher_is_better(self):
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_decontamination_query(self, doc):
        return doc

    def doc_to_text(self, doc):
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc, ctx):
        assert not ctx
        req = rf.loglikelihood_rolling(self.doc_to_target(doc))
        return req

    def process_results(self, doc, results):
        (loglikelihood,) = results
        words = self.count_words(doc)
        bytes_ = self.count_bytes(doc)
        return {
            "word_perplexity": (loglikelihood, words),
            "byte_perplexity": (loglikelihood, bytes_),
            "bits_per_byte": (loglikelihood, bytes_),
        }

    def aggregation(self):
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": bits_per_byte,
        }
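
    # How these pieces fit together (hedged sketch): process_results emits
    # (total_loglikelihood, weight) pairs per document, where the weight is the
    # word or byte count, and the aggregation functions combine them corpus-wide,
    # roughly
    #     word_perplexity ~= exp(-sum(loglikelihoods) / sum(word_counts))
    #     bits_per_byte   ~= -sum(loglikelihoods) / (sum(byte_counts) * ln(2))
    # (the exact formulas live in lm_eval.metrics; shown here only for orientation).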

    @classmethod
    def count_bytes(cls, doc):
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc):
        """Downstream tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))


def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()
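
# Illustrative cache-key behaviour (digest value not shown; it is deterministic):
#
#     hash_args("loglikelihood", ("The capital of France is", " Paris"))
#     # always returns the same SHA-256 hex string for the same (method, args)
#     # pair, which is what CachingLM below relies on for lookups.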


class CacheHook:
    def __init__(self, cachinglm):
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res):
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db):
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())
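
        # Illustrative usage (hedged sketch; `MyLM` is a hypothetical LM subclass):
        #
        #     lm = CachingLM(MyLM(), "lm_cache/mylm.db")
        #     lm.loglikelihood(reqs)   # first call runs the model and fills the cache
        #     lm.loglikelihood(reqs)   # identical requests are then served from SQLite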

    def __getattr__(self, attr):
        def fn(requests):
            res = []
            remaining_reqs = []

            # figure out which ones are cached and which ones are new
            for req in requests:
                hsh = hash_args(attr, req)
                if hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)

            # actually run the LM on the requests that do not have cached results
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        return CacheHook(self)


REQUEST_RETURN_LENGTHS = {
    "loglikelihood": 2,
    "greedy_until": None,
    "loglikelihood_rolling": None,
}


class Request:
    def __init__(self, request_type, args, index=None):
        if request_type not in REQUEST_RETURN_LENGTHS.keys():
            raise NotImplementedError(
                "The request type {} is not implemented!".format(request_type)
            )

        self.request_type = request_type
        self.args = args
        self.index = index

    def __iter__(self):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
            yield Request(self.request_type, self.args, i)

    def __getitem__(self, i):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        return Request(self.request_type, self.args, i)

    def __eq__(self, other):
        return (
            self.request_type == other.request_type
            and self.args == other.args
            and self.index == other.index
        )

    def __repr__(self):
        return f"Req_{self.request_type}{self.args}[{self.index}]\n"


class RequestFactory:
    def __getattr__(self, attr):
        def fn(*args):
            return Request(attr, args)

        return fn


rf = RequestFactory()
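
# Illustrative use of the factory (hedged sketch): a Task's construct_requests
# builds deferred requests rather than calling the LM directly, e.g.
#
#     req = rf.loglikelihood("The capital of France is", " Paris")
#     req[0]   # selects the log-probability slot of the eventual 2-tuple result,
#              # as MultipleChoiceTask.construct_requests does above.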