import abc
from typing import Iterable
import numpy as np
import random
import re
import os
import json
import hashlib
import datasets
from sqlitedict import SqliteDict
from tqdm import tqdm
import torch
import torch.nn.functional as F
from accelerate import find_executable_batch_size

from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
from lm_eval import utils
from abc import abstractmethod


class LM(abc.ABC):
    def __init__(self):
        self.cache_hook = CacheHook(None)

    @abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
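
        Illustrative example (a sketch, not real model output; the values
        depend entirely on the model and tokenizer):
            >>> lm.loglikelihood([("The cat sat on the", " mat")])  # doctest: +SKIP
            [(-2.31, True)]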
        """
        pass

    @abstractmethod
    def loglikelihood_rolling(self, requests):
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list
            A list of strings
            string: str
                String for which we are computing per-token loglikelihood
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass

    # TODO: Add an optional max length
    @abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of continuation strings
            continuation: str
                The generated continuation.
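
        Illustrative example (a sketch, not real model output):
            >>> lm.greedy_until([("The capital of France is", ["\n"])])  # doctest: +SKIP
            [" Paris"]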
        """
        pass

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config=None):
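        """Create an instance from a comma-separated `key=value` string, e.g.
        (hypothetical values) "pretrained=gpt2,batch_size=8". Non-None entries
        of `additional_config` are passed through as extra constructor kwargs.
        """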
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)

    def set_cache_hook(self, cache_hook):
        self.cache_hook = cache_hook


class BaseLM(LM):
    def __init__(self):
        super().__init__()
        self.batch_schedule = 1
        self.batch_sizes = {}
        self.max_batch_size = 512

    @property
    @abstractmethod
    def eot_token_id(self):
        pass

    @property
    @abstractmethod
    def max_length(self):
        pass

    @property
    @abstractmethod
    def max_gen_toks(self):
        pass

    @property
    @abstractmethod
    def batch_size(self):
        pass

    @property
    @abstractmethod
    def device(self):
        pass

    @abstractmethod
    def tok_encode(self, string: str):
        pass

    @abstractmethod
    def tok_decode(self, tokens: Iterable[int]):
        pass

    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id):
        pass

    @abstractmethod
    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        pass

    def _detect_batch_size(self, requests=None, pos=0):
        if requests:
            _, context_enc, continuation_enc = requests[pos]
            max_length = len(
                (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
            )
        else:
            max_length = self.max_length

        # if OOM, then halve the batch size and try again
        @find_executable_batch_size(starting_batch_size=self.max_batch_size)
        def forward_batch(batch_size):
            test_batch = torch.ones((batch_size, max_length), device=self.device).long()
            for _ in range(5):
                _ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
            return batch_size

        batch_size = forward_batch()
        utils.clear_torch_cache()

        return batch_size

    # subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
    # TODO: enforce this somehow

    def _encode_pair(self, context, continuation):
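        """Encode a (context, continuation) pair for scoring. Trailing
        whitespace is moved from the context onto the continuation, since many
        BPE tokenizers fold a leading space into the token that follows it;
        the continuation tokens are then taken as the tail of
        tok_encode(context + continuation), so merges across the boundary are
        attributed to the continuation.
        """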
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]
        whole_enc = self.tok_encode(context + continuation)
        context_enc = self.tok_encode(context)
        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]
        return context_enc, continuation_enc

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
                    continuation
                )
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        # TODO: Implement caching once we've confirmed the perplexity implementation

        # automatic batch size detection for vectorization
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        loglikelihoods = []
        for (string,) in tqdm(requests):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )
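            # e.g. with max_length=4, prefix EOT and tokens [0..9], this yields
            # (input, pred) windows ([EOT,0,1,2],[0,1,2,3]), ([3,4,5,6],[4,5,6,7]),
            # ([5,6,7,8],[8,9]) -- every token is scored exactly once (see the
            # LM.loglikelihood_rolling docstring above)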

            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows,
                disable_tqdm=True,
                override_bs=adaptive_batch_size,
            )

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False, override_bs=None):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        reordered_requests = re_ord.get_reordered()
        n_reordered_requests = len(reordered_requests)

        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
        def _batch_scheduler(pos):
            sched = pos // int(n_reordered_requests / self.batch_schedule)
            if sched in self.batch_sizes:
                return self.batch_sizes[sched]
            print(
                f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
            )
            self.batch_sizes[sched] = self._detect_batch_size(reordered_requests, pos)
            print(f"Determined largest batch size: {self.batch_sizes[sched]}")
            return self.batch_sizes[sched]

        for chunk in utils.chunks(
            tqdm(reordered_requests, disable=disable_tqdm),
            n=self.batch_size
            if self.batch_size != "auto"
            else override_bs
            if override_bs is not None
            else 0,
            fn=_batch_scheduler
            if self.batch_size == "auto" and n_reordered_requests > 0 and not override_bs
            else None,
        ):
            inps = []
            cont_toks_list = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # gpt2    \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                    dtype=torch.long,
                ).to(self.device)
                (inplen,) = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = (
                    padding_length if padding_length is not None else inplen
                )

                # pad length from seq to padding_length
                inp = torch.cat(
                    [
                        inp,  # [seq]
                        torch.zeros(padding_length - inplen, dtype=torch.long).to(
                            inp.device
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                )

                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(cont)
                inplens.append(inplen)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(
                self._model_call(batched_inps), dim=-1
            ).cpu()  # [batch, padding_length, vocab]

            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
                chunk, multi_logits, inps, inplens, cont_toks_list
            ):

                # Slice to original seq length
                contlen = len(cont_toks)
                inplen = inplen + (logits.shape[0] - padding_length) # if "virtual tokens" (from prompt tuning) are added, inplen is larger
                logits = logits[inplen - contlen : inplen].unsqueeze(
                    0
                )  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
                    0
                )  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                    -1
                )  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return re_ord.get_original(res)

    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles until that are
        #       multiple tokens or that span multiple tokens correctly

        # TODO: extract to TokenizedLM?
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        warn_stop_seq = False
        for context, request_args in tqdm(re_ord.get_reordered()):
            until = request_args["until"]
            if isinstance(until, str):
                until = [until]

            if until:
                try:
                    (primary_until,) = self.tok_encode(until[0])
                except ValueError:
                    if not warn_stop_seq:
                        print(
                            "Warning: a primary stop sequence is multi-token! Will default to EOS token for this tokenizer. Consider using `hf-causal-experimental` for multi-token stop sequence support for the time being."
                        )
                        warn_stop_seq = True
                    primary_until = self.eot_token_id
            else:
                primary_until = None

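            # keep only the last (max_length - max_gen_toks) context tokens
            # (negative-index slice), leaving room for the generated tokens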
            context_enc = torch.tensor(
                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
            ).to(self.device)

            max_gen_tokens = min(
                self.max_gen_toks, request_args.get("max_length", self.max_gen_toks)
            )
            cont = self._model_generate(
                context_enc, context_enc.shape[1] + max_gen_tokens, primary_until
            )

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])

            for term in until:
                s = s.split(term)[0]

            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)

            res.append(s)

        return re_ord.get_original(res)


class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary e.g.
        {"question": ..., "answer": ...} or
        {"question": ..., question, answer)
    """

    # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
    # or a path to a custom `datasets` loading script.
    DATASET_PATH: str = None

    # The name of a subset within `DATASET_PATH`.
    DATASET_NAME: str = None

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None):
        """
        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.download(data_dir, cache_dir, download_mode)
        self._training_docs = None
        self._fewshot_docs = None

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Downloads and returns the task dataset.
        Override this method to download the dataset from a custom API.

        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
        )

    def should_decontaminate(self):
        """Whether this task supports decontamination against model training set."""
        return False

    @abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def test_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def _process_doc(self, doc):
        """
        Override this to process (detokenize, strip, replace, etc.) individual
        documents. This can be used in a map over documents of a data split.
        E.g. `map(self._process_doc, self.dataset["validation"])`

        :return: dict
            The processed version of the specified `doc`.
        """
        return doc

    def fewshot_examples(self, k, rnd):
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    def doc_to_decontamination_query(self, doc):
        print(
            "Override doc_to_decontamination_query with document specific decontamination query."
        )
        assert False

    @abstractmethod
    def doc_to_text(self, doc):
        pass

    @abstractmethod
    def doc_to_target(self, doc):
        pass

    @abstractmethod
    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        pass

    @abstractmethod
    def process_results(self, doc, results):
        """Takes a single document and the LM results, evaluates them, and returns a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass

    @abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        pass

    def fewshot_description(self):
        import warnings

        warnings.warn(
            "`fewshot_description` will be removed in future versions. Pass "
            "any custom descriptions to the `evaluate` function instead.",
            DeprecationWarning,
        )
        return ""

    @utils.positional_deprecated
    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param provide_description: bool
            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
        :param rnd: random.Random
            The pseudo-random number generator used to randomly sample examples.
            WARNING: This is currently a required arg although it's optionalized with a default `None`.
        :param description: str
            The task's description that will be prepended to the fewshot examples.
        :returns: str
            The fewshot context.
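
        Illustrative sketch (the doc contents are made up): with
        description="Answer the question.", num_fewshot=1, a sampled example
        rendering to "Q: 1+1=" / " 2", and `doc` rendering to "Q: 2+2=",
        the returned string is:
            "Answer the question.\n\nQ: 1+1= 2\n\nQ: 2+2="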
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        description = description + "\n\n" if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs()
                        if self.has_validation_docs()
                        else self.test_docs()
                    )

                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = (
                "\n\n".join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + "\n\n"
            )

        example = self.doc_to_text(doc)
        return description + labeled_examples + example


class MultipleChoiceTask(Task):
    def doc_to_target(self, doc):
        return " " + doc["choices"][doc["gold"]]

    def construct_requests(self, doc, ctx):
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
        ]

        return lls

    def process_results(self, doc, results):
        gold = doc["gold"]

        acc = 1.0 if np.argmax(results) == gold else 0.0
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
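        # acc_norm is length-normalized: each choice's total loglikelihood is
        # divided by its character length so longer choices are not penalized
        # merely for containing more tokens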
        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }

    def higher_is_better(self):
        return {
            "acc": True,
            "acc_norm": True,
        }

    def aggregation(self):
        return {
            "acc": mean,
            "acc_norm": mean,
        }


class PerplexityTask(Task, abc.ABC):
    def should_decontaminate(self):
        """Whether this task supports decontamination against model training set."""
        return True

    def has_training_docs(self):
        return False

    def fewshot_examples(self, k, rnd):
        assert k == 0
        return []

    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        assert (
            num_fewshot == 0
        ), "The number of fewshot examples must be 0 for perplexity tasks."
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`."
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        return ""

    def higher_is_better(self):
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_decontamination_query(self, doc):
        return doc

    def doc_to_text(self, doc):
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc, ctx):
        assert not ctx
        req = rf.loglikelihood_rolling(self.doc_to_target(doc))
        return req

    def process_results(self, doc, results):
        (loglikelihood,) = results
        words = self.count_words(doc)
        bytes_ = self.count_bytes(doc)
        return {
            "word_perplexity": (loglikelihood, words),
            "byte_perplexity": (loglikelihood, bytes_),
            "bits_per_byte": (loglikelihood, bytes_),
        }
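
    # Each metric value above is a (loglikelihood, weight) pair; the
    # aggregations below combine them at the corpus level (weighted_perplexity
    # reduces to exp(-sum(loglikelihoods) / sum(weights)) across documents).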

    def aggregation(self):
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": bits_per_byte,
        }

    @classmethod
    def count_bytes(cls, doc):
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc):
        """Downstream tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))


def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()


class CacheHook:
    def __init__(self, cachinglm):
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res):
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db):
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
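
        Usage sketch (the path and wrapped LM are illustrative):
            >>> lm = CachingLM(my_lm, "lm_cache/model.db")  # doctest: +SKIP
            >>> lm.loglikelihood(reqs)  # computed by my_lm, results cached
            >>> lm.loglikelihood(reqs)  # identical call now served from SQLite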
        """
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        lm_attr = getattr(self.lm, attr)
        if not callable(lm_attr):
            return lm_attr

        def fn(requests):
            res = []
            remaining_reqs = []

            # figure out which ones are cached and which ones are new
            for req in requests:
                hsh = hash_args(attr, req)
                if hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)

            # actually run the LM on the requests that do not have cached results
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        return CacheHook(self)


REQUEST_RETURN_LENGTHS = {
    "loglikelihood": 2,
    "greedy_until": None,
    "loglikelihood_rolling": None,
}


class Request:
    def __init__(self, request_type, args, index=None):
        if request_type not in REQUEST_RETURN_LENGTHS.keys():
            raise NotImplementedError(
                "The request type {} is not implemented!".format(request_type)
            )

        self.request_type = request_type
        self.args = args
        self.index = index

    def __iter__(self):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
            yield Request(self.request_type, self.args, i)

    def __getitem__(self, i):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        return Request(self.request_type, self.args, i)

    def __eq__(self, other):
        return (
            self.request_type == other.request_type
            and self.args == other.args
            and self.index == other.index
        )

    def __repr__(self):
        return f"Req_{self.request_type}{self.args}[{self.index}]\n"


class RequestFactory:
    def __getattr__(self, attr):
        def fn(*args):
            return Request(attr, args)

        return fn


rf = RequestFactory()
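
# Illustrative note: rf.loglikelihood(ctx, cont) builds
# Request("loglikelihood", (ctx, cont)). Because loglikelihood returns two
# values (see REQUEST_RETURN_LENGTHS), tasks may index or unpack the request:
#   ll = rf.loglikelihood("some context", " some continuation")[0]
# Each indexed element is a Request whose `index` records which of the two
# returned values (logprob, is_greedy) it should resolve to during evaluation.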