base.py 32.1 KB
Newer Older
Leo Gao's avatar
Leo Gao committed
1
import abc
2
from typing import Iterable
cjlovering's avatar
cjlovering committed
3

thefazzer's avatar
thefazzer committed
4
import numpy as np
5
import random
Leo Gao's avatar
Leo Gao committed
6
import re
7
8
9
import os
import json
import hashlib
Jonathan Tow's avatar
Jonathan Tow committed
10
import datasets
11
from sqlitedict import SqliteDict
12
from tqdm import tqdm
13
import torch
Leo Gao's avatar
Leo Gao committed
14
import torch.nn.functional as F
&'s avatar
& committed
15

16
from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
17
from lm_eval import utils
18
from abc import abstractmethod
Jason Phang's avatar
gpt3  
Jason Phang committed
19

Jason Phang's avatar
Jason Phang committed
20

Leo Gao's avatar
Leo Gao committed
21
class LM(abc.ABC):
    """Abstract interface for a language model backend.

    Concrete backends implement the three request-level primitives
    (`loglikelihood`, `loglikelihood_rolling`, `greedy_until`); evaluation
    code talks to models only through this interface.
    """

    def __init__(self):
        # CacheHook is defined elsewhere in this module; a None-backed hook
        # makes partial-result caching a no-op until `set_cache_hook`
        # installs a real one.
        self.cache_hook = CacheHook(None)

    @abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass

    @abstractmethod
    def loglikelihood_rolling(self, requests):
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list
            A list of strings
            string: str
                String for which we are computing per-token loglikelihood
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass

    # TODO: Add an optional max length
    @abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of strings continuation
            continuation: str
                The generated continuation.
        """
        pass

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config=None):
        """Instantiate the model from a comma-separated "key=value" string.

        :param arg_string: str
            Parsed by `utils.simple_parse_args_string` into constructor kwargs,
            e.g. "pretrained=gpt2,device=cpu".
        :param additional_config: dict | None
            Extra constructor kwargs; entries whose value is None are dropped.
        :return: LM
            A new instance of `cls`.
        """
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)

    def set_cache_hook(self, cache_hook):
        """Install the hook used for partial-result caching of LM calls."""
        self.cache_hook = cache_hook
121
class BaseLM(LM):
    """Shared scoring/generation logic for token-level language models.

    Subclasses provide tokenization (`tok_encode`/`tok_decode`), the forward
    pass (`_model_call`), generation (`_model_generate`) and the properties
    below; this class implements request batching, padding and scoring on
    top of them.
    """

    @property
    @abstractmethod
    def eot_token_id(self):
        """Token id of the end-of-text token (also used as the empty-context prefix)."""
        pass

    @property
    @abstractmethod
    def max_length(self):
        """Maximum number of tokens the model can attend to in one call."""
        pass

    @property
    @abstractmethod
    def max_gen_toks(self):
        """Maximum number of tokens to generate in `greedy_until`."""
        pass

    @property
    @abstractmethod
    def batch_size(self):
        """Number of requests scored per `_model_call`."""
        pass

    @property
    @abstractmethod
    def device(self):
        """Device (e.g. a torch device) that input tensors are moved to."""
        pass

    @abstractmethod
    def tok_encode(self, string: str):
        """Encode a string into a list of token ids."""
        pass

    @abstractmethod
    def tok_decode(self, tokens: Iterable[int]):
        """Decode a sequence of token ids back into a string."""
        pass

    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id):
        """Greedily generate from `context` up to `max_length` total tokens, stopping at `eos_token_id`."""
        pass

    @abstractmethod
    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        pass

    # subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
    # TODO: enforce this somehow

    def loglikelihood(self, requests):
        """Tokenize (context, continuation) pairs and delegate scoring to
        `_loglikelihood_tokens`. See `LM.loglikelihood` for the contract."""
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context)

            continuation_enc = self.tok_encode(continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        """Score each string over rolling windows of the full context length.
        See `LM.loglikelihood_rolling` for the windowing contract."""
        # TODO: Implement caching once we've confirmed the perplexity implementation
        # TODO: automatic batch size detection for vectorization

        loglikelihoods = []
        for (string,) in tqdm(requests):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            # prepend a None cache_key so each window matches the
            # (cache_key, context_enc, continuation_enc) request shape
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows, disable_tqdm=True
            )

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]

            # total loglikelihood of the document is the sum over its windows
            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        """Score pre-tokenized (cache_key, context_enc, continuation_enc)
        requests, returning a list of (logprob, is_greedy) pairs in the
        original request order."""
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        # TODO: automatic (variable) batch size detection for vectorization
        reord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
            tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size
        ):
            inps = []
            cont_toks_list = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # gpt2    \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                    dtype=torch.long,
                ).to(self.device)
                (inplen,) = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = (
                    padding_length if padding_length is not None else inplen
                )

                # pad length from seq to padding_length
                inp = torch.cat(
                    [
                        inp,  # [seq]
                        torch.zeros(padding_length - inplen, dtype=torch.long).to(
                            inp.device
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                )

                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(cont)
                inplens.append(inplen)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(
                self._model_call(batched_inps), dim=-1
            ).cpu()  # [batch, padding_length, vocab]

            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
                chunk, multi_logits, inps, inplens, cont_toks_list
            ):

                # Slice to original seq length
                contlen = len(cont_toks)
                logits = logits[inplen - contlen : inplen].unsqueeze(
                    0
                )  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
                    0
                )  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                    -1
                )  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return reord.get_original(res)

    def greedy_until(self, requests):
        """Greedy generation until a stop sequence. See `LM.greedy_until`.

        Only the first token of `until[0]` is used as the model-level stop
        token; remaining stop strings are applied by post-hoc splitting."""
        # TODO: implement fully general `until` that handles untils that are
        #       multiple tokens or that span multiple tokens correctly

        # TODO: extract to TokenizedLM?
        res = []

        def _collate(x):
            # sort ascending by tokenized context length
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        reord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(reord.get_reordered()):
            if isinstance(until, str):
                until = [until]

            # assumes until[0] encodes to exactly one token — TODO confirm for all tokenizers
            (primary_until,) = self.tok_encode(until[0])

            # truncate context from the left, leaving room for max_gen_toks new tokens
            context_enc = torch.tensor(
                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
            ).to(self.device)

            cont = self._model_generate(
                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
            )

            # decode only the newly generated tail
            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])

            for term in until:
                s = s.split(term)[0]

            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)

            res.append(s)

        return reord.get_original(res)
Leo Gao's avatar
Leo Gao committed
371

Leo Gao's avatar
Leo Gao committed
372

373
class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary e.g.
        {"question": ..., "answer": ...} or
        (question, answer)
    """

    # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
    # or a path to a custom `datasets` loading script.
    DATASET_PATH: str = None

    # The name of a subset within `DATASET_PATH`.
    DATASET_NAME: str = None

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None):
        """
        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.download(data_dir, cache_dir, download_mode)
        # lazy caches used by fewshot sampling (see fewshot_examples/fewshot_context)
        self._training_docs = None
        self._fewshot_docs = None

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Downloads and returns the task dataset.
        Override this method to download the dataset from a custom API.

        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
        )

    @abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object, that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object, that doc_to_text can handle
        """
        return []

    def test_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object, that doc_to_text can handle
        """
        return []

    def _process_doc(self, doc):
        """
        Override this to process (detokenize, strip, replace, etc.) individual
        documents. This can be used in a map over documents of a data split.
        E.g. `map(self._process_doc, self.dataset["validation"])`

        :return: dict
            The processed version of the specified `doc`.
        """
        return doc

    def fewshot_examples(self, k, rnd):
        """Sample `k` training documents without replacement using `rnd`.

        :param k: int
            Number of examples to draw.
        :param rnd: random.Random
            Generator used for sampling (training docs are materialized and
            cached on first call).
        """
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    @abstractmethod
    def doc_to_text(self, doc):
        """Render `doc` into the prompt/question text shown to the model."""
        pass

    @abstractmethod
    def doc_to_target(self, doc):
        """Render `doc` into the target/answer text to be scored."""
        pass

    @abstractmethod
    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        pass

    @abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass

    @abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        pass

    def fewshot_description(self):
        """Deprecated: descriptions are now passed to `evaluate` directly."""
        import warnings

        warnings.warn(
            "`fewshot_description` will be removed in futures versions. Pass "
            "any custom descriptions to the `evaluate` function instead.",
            DeprecationWarning,
        )
        return ""

    @utils.positional_deprecated
    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param provide_description: bool
            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
        :param rnd: random.Random
            The pseudo-random number generator used to randomly sample examples.
            WARNING: This is currently a required arg although it's optionalized with a default `None`.
        :param description: str
            The task's description that will be prepended to the fewshot examples.
        :returns: str
            The fewshot context.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        description = description + "\n\n" if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs()
                        if self.has_validation_docs()
                        else self.test_docs()
                    )

                # oversample by one so we can drop the evaluated doc below
                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = (
                "\n\n".join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + "\n\n"
            )

        example = self.doc_to_text(doc)
        return description + labeled_examples + example

cjlovering's avatar
cjlovering committed
636
637
638
639
class PromptSourceTask(Task):
    """A `Task` whose input/target rendering is delegated to a promptsource-style
    template.

    The `prompt` object is assumed to expose ``apply(doc) -> (text, target)``,
    ``get_fixed_answer_choices_list()`` and a ``metadata`` namespace with
    ``choices_in_prompt`` / ``metrics`` fields.
    # NOTE(review): interface inferred from usage below — confirm against the
    # promptsource `Template` API.
    """

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None):
        super().__init__(data_dir, cache_dir, download_mode)
        self.prompt = prompt

    def doc_to_target(self, doc):
        # Fix: the template lives on the instance (`self.prompt`); the bare
        # name `prompt` was a NameError at call time.
        _, target = self.prompt.apply(doc)
        # Space-prefixed so the continuation tokenizes cleanly after the context.
        return f" {target}"

    def doc_to_text(self, doc):
        # Fix: `self.prompt`, not the undefined bare name `prompt`.
        text, _ = self.prompt.apply(doc)
        return text

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        _requests = []

        if self.prompt.metadata.choices_in_prompt:
            # One loglikelihood request per fixed answer choice; keep only the
            # logprob half of the (logprob, is_greedy) pair.
            for answer_choice in self.prompt.get_fixed_answer_choices_list():
                ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}")
                _requests.append(ll_answer_choice)
        else:
            # TODO(Albert): What is the stop symbol? Is it model specific?
            # Fix: `greedy_until` requests cannot be tuple-unpacked — their
            # REQUEST_RETURN_LENGTHS entry is None, so iterating the request
            # raises IndexError. Keep the request object whole.
            ll_greedy = rf.greedy_until(ctx, ["\nQ:"])
            _requests.append(ll_greedy)

        return _requests

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        raise NotImplementedError(
            "Implement process results using the `prompt.metadata.metrics`. See below."
        )
        # Unreachable sketch of the intended implementation (kept for the
        # eventual implementer); references fixed to use `self.prompt`.
        if self.prompt.metadata.choices_in_prompt:
            for result, answer_choice in zip(
                self.prompt.get_fixed_answer_choices_list(), results
            ):
                pass
        else:
            continuation = results

        # Map metric name to HF metric.
        # TODO(Albert): What is Other?
        metric_names = self.prompt.metadata.metrics


class MultipleChoiceTask(Task):
    """Task whose answer is one of a fixed set of per-document choices.

    Each doc must provide ``doc["choices"]`` (candidate completion strings)
    and ``doc["gold"]`` (index of the correct choice).
    """

    def doc_to_target(self, doc):
        # The gold completion, space-prefixed so it tokenizes cleanly
        # after the context.
        return " " + doc["choices"][doc["gold"]]

    def construct_requests(self, doc, ctx):
        # One loglikelihood request per candidate; keep only the logprob
        # (index 0), dropping the is_greedy flag.
        requests = []
        for option in doc["choices"]:
            requests.append(rf.loglikelihood(ctx, " {}".format(option))[0])
        return requests

    def process_results(self, doc, results):
        gold_index = doc["gold"]

        # Unnormalized accuracy: did the gold choice get the highest logprob?
        acc = float(np.argmax(results) == gold_index)

        # Length-normalized accuracy: divide each logprob by the choice's
        # character length to reduce the bias toward shorter completions.
        lengths = np.array([float(len(choice)) for choice in doc["choices"]])
        acc_norm = float(np.argmax(results / lengths) == gold_index)

        return {"acc": acc, "acc_norm": acc_norm}

    def higher_is_better(self):
        return {"acc": True, "acc_norm": True}

    def aggregation(self):
        return {"acc": mean, "acc_norm": mean}


Jason Phang's avatar
Jason Phang committed
735
736
737
738
739
740
741
742
class PerplexityTask(Task, abc.ABC):
    """Rolling-loglikelihood task: the model scores each document verbatim.

    No prompt or few-shot context is constructed; the metrics are
    perplexities over the raw document text.
    """

    def has_training_docs(self):
        return False

    def fewshot_examples(self, k, rnd):
        # Perplexity tasks are strictly zero-shot.
        assert k == 0
        return []

    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        assert (
            num_fewshot == 0
        ), "The number of fewshot examples must be 0 for perplexity tasks."
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`."
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        # The context is always empty: the document itself is the target.
        return ""

    def higher_is_better(self):
        # All three metrics are "lower is better".
        return dict.fromkeys(
            ["word_perplexity", "byte_perplexity", "bits_per_byte"], False
        )

    def doc_to_text(self, doc):
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc, ctx):
        # fewshot_context always returns "" for perplexity tasks.
        assert not ctx
        return rf.loglikelihood_rolling(self.doc_to_target(doc))

    def process_results(self, doc, results):
        (ll,) = results
        n_words = self.count_words(doc)
        n_bytes = self.count_bytes(doc)
        # Each metric is a (loglikelihood, weight) pair consumed by the
        # weighted aggregators below.
        return {
            "word_perplexity": (ll, n_words),
            "byte_perplexity": (ll, n_bytes),
            "bits_per_byte": (ll, n_bytes),
        }

    def aggregation(self):
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": bits_per_byte,
        }

    @classmethod
    def count_bytes(cls, doc):
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc):
        """Downstream tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))
Leo Gao's avatar
Leo Gao committed
808

Jason Phang's avatar
Jason Phang committed
809

Leo Gao's avatar
Leo Gao committed
810
811
def hash_args(attr, args):
    """Deterministic cache key: SHA-256 hex digest of the JSON-encoded
    ``[attr, *args]`` list."""
    payload = json.dumps([attr, *args])
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
Leo Gao's avatar
Leo Gao committed
813
814


Leo Gao's avatar
Leo Gao committed
815
816
class CacheHook:
    """Write-through hook that lets an LM persist partial results into a
    CachingLM's backing store as they are produced."""

    def __init__(self, cachinglm):
        # With no backing CachingLM, the hook is a no-op.
        self.dbdict = cachinglm.dbdict if cachinglm is not None else None

    def add_partial(self, attr, req, res):
        """Store one result under the hash of (attr, req); no-op when
        there is no backing store."""
        if self.dbdict is None:
            return
        self.dbdict[hash_args(attr, req)] = res


Leo Gao's avatar
Leo Gao committed
830
831
class CachingLM:
    def __init__(self, lm, cache_db):
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        # Make sure the parent directory exists before SqliteDict opens the file.
        parent = os.path.dirname(cache_db)
        if parent:
            os.makedirs(parent, exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # Let the wrapped LM write partial results into our cache as it goes.
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        # Any LM method (loglikelihood, greedy_until, ...) is proxied through
        # a caching wrapper that preserves the order of `requests`.
        def fn(requests):
            results = []
            misses = []

            # Partition requests into cache hits (stored inline) and misses
            # (a None placeholder, filled in below).
            for request in requests:
                key = hash_args(attr, request)
                if key in self.dbdict:
                    cached = self.dbdict[key]
                    assert cached is not None
                    results.append(cached)
                else:
                    results.append(None)
                    misses.append(request)

            # actually run the LM on the requests that do not have cached results
            fresh = getattr(self.lm, attr)(misses)

            # Drop each fresh result into the next placeholder slot (misses
            # were collected in order, so this realigns them) and persist it.
            slot = 0
            for request, value in zip(misses, fresh):
                while results[slot] is not None:
                    slot += 1
                results[slot] = value
                self.dbdict[hash_args(attr, request)] = value
            self.dbdict.commit()

            return results

        return fn

    def get_cache_hook(self):
        return CacheHook(self)
Leo Gao's avatar
Leo Gao committed
888

Jason Phang's avatar
Jason Phang committed
889

890
# Number of indexed sub-results each request type yields: `loglikelihood`
# returns a (logprob, is_greedy) pair; `None` marks types whose result
# cannot be unpacked into indexed sub-requests.
REQUEST_RETURN_LENGTHS = {
    "loglikelihood": 2,
    "greedy_until": None,
    "loglikelihood_rolling": None,
}


class Request:
    """A deferred LM call of a given type with its arguments.

    :param request_type: str
        One of the keys of REQUEST_RETURN_LENGTHS.
    :param args: tuple
        Positional arguments for the eventual LM call.
    :param index: int | None
        Which element of a multi-valued result this request selects,
        or None for the whole result.
    :raises NotImplementedError: if `request_type` is unknown.
    """

    def __init__(self, request_type, args, index=None):
        if request_type not in REQUEST_RETURN_LENGTHS:
            raise NotImplementedError(
                "The request type {} is not implemented!".format(request_type)
            )

        self.request_type = request_type
        self.args = args
        self.index = index

    def __iter__(self):
        # Only multi-valued request types can be split into indexed parts.
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
            yield Request(self.request_type, self.args, i)

    def __getitem__(self, i):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        return Request(self.request_type, self.args, i)

    def __eq__(self, other):
        # Fix: comparing against a non-Request used to raise AttributeError;
        # return NotImplemented so Python falls back to the other operand
        # (and ultimately to False for `==`).
        if not isinstance(other, Request):
            return NotImplemented
        return (
            self.request_type == other.request_type
            and self.args == other.args
            and self.index == other.index
        )

    def __repr__(self):
        return f"Req_{self.request_type}{self.args}[{self.index}]\n"
928

Jason Phang's avatar
Jason Phang committed
929

Leo Gao's avatar
Leo Gao committed
930
931
class RequestFactory:
    """Syntactic sugar for building `Request` objects: ``rf.loglikelihood(a, b)``
    is equivalent to ``Request("loglikelihood", (a, b))``."""

    def __getattr__(self, attr):
        # Any attribute name acts as a request constructor for that type;
        # Request.__init__ validates the name when the closure is called.
        def make_request(*args):
            return Request(attr, args)

        return make_request


rf = RequestFactory()