import abc
from typing import Iterable

import promptsource
import numpy as np
import random
import re
import os
import json
import hashlib
import datasets
from sqlitedict import SqliteDict
from tqdm import tqdm
import torch
import torch.nn.functional as F

from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
from lm_eval import utils
from abc import abstractmethod


class LM(abc.ABC):
    def __init__(self):
        self.cache_hook = CacheHook(None)

    @abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass
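    # Illustrative (hypothetical) call; the numbers below are made up:
    #
    #     lm.loglikelihood([("The quick brown fox", " jumps over the lazy dog")])
    #     # -> [(-12.3, True)]
    #
    # i.e. one (logprob, is_greedy) pair is returned per (context, continuation) request.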

    @abstractmethod
    def loglikelihood_rolling(self, requests):
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list
            A list of strings
            string: str
                String for which we are computing per-token loglikelihood
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass

    # TODO: Add an optional max length
    @abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of strings continuation
            continuation: str
                The generated continuation.
        """
        pass

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config=None):
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)
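    # Hypothetical usage (the subclass name and argument values are examples, not from this file):
    #
    #     lm = MyLM.create_from_arg_string("pretrained=gpt2,batch_size=8", {"device": "cuda:0"})
    #
    # `utils.simple_parse_args_string` is expected to turn the comma-separated key=value pairs
    # into a dict; non-None entries of `additional_config` are then merged on top of it.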

    def set_cache_hook(self, cache_hook):
        self.cache_hook = cache_hook


class BaseLM(LM):
    @property
    @abstractmethod
    def eot_token_id(self):
        pass

    @property
    @abstractmethod
    def max_length(self):
        pass

    @property
    @abstractmethod
    def max_gen_toks(self):
        pass

    @property
    @abstractmethod
    def batch_size(self):
        pass

    @property
    @abstractmethod
    def device(self):
        pass

    @abstractmethod
    def tok_encode(self, string: str):
        pass

    @abstractmethod
    def tok_decode(self, tokens: Iterable[int]):
        pass

    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id):
        pass

    @abstractmethod
    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        pass

    # subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
    # TODO: enforce this somehow

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context)

            continuation_enc = self.tok_encode(continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        # TODO: Implement caching once we've confirmed the perplexity implementation
        # TODO: automatic batch size detection for vectorization

        loglikelihoods = []
        for (string,) in tqdm(requests):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows, disable_tqdm=True
            )

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        # TODO: automatic (variable) batch size detection for vectorization
        reord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
            tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size
        ):
            inps = []
            cont_toks_list = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # gpt2    \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                    dtype=torch.long,
                ).to(self.device)
                (inplen,) = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = (
                    padding_length if padding_length is not None else inplen
                )

                # pad length from seq to padding_length
                inp = torch.cat(
                    [
                        inp,  # [seq]
                        torch.zeros(padding_length - inplen, dtype=torch.long).to(
                            inp.device
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                )

                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(cont)
                inplens.append(inplen)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(
                self._model_call(batched_inps), dim=-1
            ).cpu()  # [batch, padding_length, vocab]

            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
                chunk, multi_logits, inps, inplens, cont_toks_list
            ):

                # Slice to original seq length
                contlen = len(cont_toks)
                logits = logits[inplen - contlen : inplen].unsqueeze(
                    0
                )  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
                    0
                )  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                    -1
                )  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return reord.get_original(res)

    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles untils that are
        #       multiple tokens or that span multiple tokens correctly

        # TODO: extract to TokenizedLM?
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        reord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(reord.get_reordered()):
            if isinstance(until, str):
                until = [until]

            # TODO: Come back to for generation `eos`.
            primary_until = self.tok_encode(until[0])

            context_enc = torch.tensor(
                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
            ).to(self.device)

            cont = self._model_generate(
                context_enc,
                context_enc.shape[1] + self.max_gen_toks,
                torch.tensor(primary_until),
            )

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])

            for term in until:
                s = s.split(term)[0]

            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)

            res.append(s)

        return reord.get_original(res)


class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary e.g.
        {"question": ..., "answer": ...} or
        (question, answer)
    """

    # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
    # or a path to a custom `datasets` loading script.
    DATASET_PATH: str = None

    # The name of a subset within `DATASET_PATH`.
    DATASET_NAME: str = None

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None):
        """
        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.download(data_dir, cache_dir, download_mode)
        self._training_docs = None
        self._fewshot_docs = None

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Downloads and returns the task dataset.
        Override this method to download the dataset from a custom API.

        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
        )

    @abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def test_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def _process_doc(self, doc):
        """
        Override this to process (detokenize, strip, replace, etc.) individual
        documents. This can be used in a map over documents of a data split.
        E.g. `map(self._process_doc, self.dataset["validation"])`

        :return: dict
            The processed version of the specified `doc`.
        """
        return doc

    def fewshot_examples(self, k, rnd):
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    @abstractmethod
    def doc_to_text(self, doc):
        pass

    @abstractmethod
    def doc_to_target(self, doc):
        pass

    @abstractmethod
    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        pass

    @abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluate, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass

    @abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        pass

    def fewshot_description(self):
        import warnings

        warnings.warn(
            "`fewshot_description` will be removed in future versions. Pass "
            "any custom descriptions to the `evaluate` function instead.",
            DeprecationWarning,
        )
        return ""

    @utils.positional_deprecated
    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param provide_description: bool
            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
        :param rnd: random.Random
            The pseudo-random number generator used to randomly sample examples.
            WARNING: This is currently a required arg although it's optionalized with a default `None`.
        :param description: str
            The task's description that will be prepended to the fewshot examples.
        :returns: str
            The fewshot context.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        description = description + "\n\n" if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs()
                        if self.has_validation_docs()
                        else self.test_docs()
                    )

                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = (
                "\n\n".join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + "\n\n"
            )

        example = self.doc_to_text(doc)
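        # Rough shape of the string returned below (illustrative):
        #
        #     {description}\n\n
        #     {doc_to_text(example_1)}{doc_to_target(example_1)}\n\n
        #     ...
        #     {doc_to_text(example_k)}{doc_to_target(example_k)}\n\n
        #     {doc_to_text(doc)}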
        return description + labeled_examples + example


class PromptSourceTask(Task):
    def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None):
        super().__init__(data_dir, cache_dir, download_mode)
        self.prompt = prompt

    def eos_token(self):
        raise NotImplementedError()

    def is_generation_task(self):
        return (
            "BLEU" in self.prompt.metadata.metrics
            or "ROUGE" in self.prompt.metadata.metrics
        )

    def doc_to_target(self, doc):
        _, target = self.prompt.apply(doc)
        return f" {target}"

    def doc_to_text(self, doc):
        text, _ = self.prompt.apply(doc)
        return text

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        _requests = []
        answer_choices_list = self.prompt.get_answer_choices_list(doc)

        # We take a present answer_choices list to mean that we should apply the supplied
        # metrics (hardcoded or accuracy atm) to the ranked choices. Otherwise, assume generation.
        # Above we do something similar, but rely on the metrics requested (BLEU, ROUGE indicating generation).
        if answer_choices_list:
            assert (
                not self.is_generation_task()
            ), f"We expect this to be a ranked choice task; double check please."
            for answer_choice in answer_choices_list:
                ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}")
                _requests.append(ll_answer_choice)
        else:
            # TODO(Albert): What is the stop symbol? Is it model specific?
            cont_request = rf.greedy_until(ctx, [self.eos_token()])
            _requests.append(cont_request)

        return _requests

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluate, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        # raise NotImplementedError(
        #     "Implement process results using the `prompt.metadata.metrics`. See below."
        # )
        target = self.doc_to_target(doc).strip()
        answer_choices_list = self.prompt.get_answer_choices_list(doc)
        if answer_choices_list:
            assert (
                not self.is_generation_task()
            ), f"We expect this to be a ranked choice task; double check please."
            pred = answer_choices_list[np.argmax(results)]
            out = {}
            if "Accuracy" in self.prompt.metadata.metrics:
                out["acc"] = pred == target
            # TODO: Add metrics here.
            return out
        else:
            raise NotImplementedError("Generation is not implemented yet.")

        # Map metric name to HF metric.
        # TODO(Albert): What is Other?
        # metric_names = prompt.metadata.metrics

    def higher_is_better(self):
        out = {}
        if "Accuracy" in self.prompt.metadata.metrics:
            out["acc"] = True

        return out

    def aggregation(self):
        out = {}
        if "Accuracy" in self.prompt.metadata.metrics:
            out["acc"] = mean

        return out


class MultipleChoiceTask(Task):
    def doc_to_target(self, doc):
        return " " + doc["choices"][doc["gold"]]

    def construct_requests(self, doc, ctx):
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
        ]

        return lls

    def process_results(self, doc, results):
        gold = doc["gold"]

        acc = 1.0 if np.argmax(results) == gold else 0.0
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }

    def higher_is_better(self):
        return {
            "acc": True,
            "acc_norm": True,
        }

    def aggregation(self):
        return {
            "acc": mean,
            "acc_norm": mean,
        }


class PerplexityTask(Task, abc.ABC):
    def has_training_docs(self):
        return False

    def fewshot_examples(self, k, rnd):
        assert k == 0
        return []

    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        assert (
            num_fewshot == 0
        ), "The number of fewshot examples must be 0 for perplexity tasks."
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`."
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        return ""

    def higher_is_better(self):
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_text(self, doc):
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc, ctx):
        assert not ctx
        req = rf.loglikelihood_rolling(self.doc_to_target(doc))
        return req

    def process_results(self, doc, results):
        (loglikelihood,) = results
        words = self.count_words(doc)
        bytes_ = self.count_bytes(doc)
        return {
            "word_perplexity": (loglikelihood, words),
            "byte_perplexity": (loglikelihood, bytes_),
            "bits_per_byte": (loglikelihood, bytes_),
        }

    def aggregation(self):
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": bits_per_byte,
        }

    @classmethod
    def count_bytes(cls, doc):
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc):
        """Downstream tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))


def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()


class CacheHook:
    def __init__(self, cachinglm):
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res):
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db):
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())
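        # Hypothetical usage (the path and wrapped LM are examples, not from this file):
        #
        #     lm = CachingLM(base_lm, "lm_cache/gpt2.db")
        #     lm.loglikelihood(requests)  # cached results are reused; misses hit the underlying LM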

    def __getattr__(self, attr):
        def fn(requests):
            res = []
            remaining_reqs = []

            # figure out which ones are cached and which ones are new
            for req in requests:
                hsh = hash_args(attr, req)
                if hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)

            # actually run the LM on the requests that do not have cached results
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        return CacheHook(self)


REQUEST_RETURN_LENGTHS = {
    "loglikelihood": 2,
    "greedy_until": None,
    "loglikelihood_rolling": None,
}
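# A value above gives the number of values a request of that type returns
# (e.g. `loglikelihood` yields a (logprob, is_greedy) pair), which lets a `Request`
# be iterated/indexed per return value; `None` marks request types that cannot be
# indexed that way (see `Request.__iter__` / `__getitem__` below).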


class Request:
    def __init__(self, request_type, args, index=None):
        if request_type not in REQUEST_RETURN_LENGTHS.keys():
            raise NotImplementedError(
                "The request type {} is not implemented!".format(request_type)
            )

        self.request_type = request_type
        self.args = args
        self.index = index

    def __iter__(self):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
            yield Request(self.request_type, self.args, i)

    def __getitem__(self, i):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        return Request(self.request_type, self.args, i)

    def __eq__(self, other):
        return (
            self.request_type == other.request_type
            and self.args == other.args
            and self.index == other.index
        )

    def __repr__(self):
        return f"Req_{self.request_type}{self.args}[{self.index}]\n"


class RequestFactory:
    def __getattr__(self, attr):
        def fn(*args):
            return Request(attr, args)

        return fn


rf = RequestFactory()
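# Hypothetical usage inside a task's `construct_requests` (mirrors the classes above):
#
#     ll_request = rf.loglikelihood(ctx, " answer")  # -> Request("loglikelihood", (ctx, " answer"))
#     logprob_req, is_greedy_req = ll_request         # one sub-Request per return value, since
#                                                     # REQUEST_RETURN_LENGTHS["loglikelihood"] == 2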