import abc
from typing import Iterable, Optional

import promptsource
import numpy as np
import random
import re
import os
import json
import hashlib
import datasets
from sqlitedict import SqliteDict
from tqdm import tqdm
import torch
import torch.nn.functional as F

from lm_eval import metrics
from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
from lm_eval import utils
from abc import abstractmethod


class LM(abc.ABC):
    def __init__(self):
        self.cache_hook = CacheHook(None)

    @abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass
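    # A minimal usage sketch (model class and numbers are hypothetical): for a
    # concrete subclass `MyLM`, a call such as
    #     MyLM().loglikelihood([("The capital of France is", " Paris")])
    # might return [(-1.2, True)], i.e. the summed log-probability of " Paris"
    # given the context, and whether " Paris" is also the greedy continuation.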

    @abstractmethod
    def loglikelihood_rolling(self, requests):
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list
            A list of strings
            string: str
                String for which we are computing per-token loglikelihood
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass
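    # Sketch of how an implementation typically windows a long document; this
    # mirrors the helper calls used by BaseLM.loglikelihood_rolling below:
    #     windows = map(
    #         utils.make_disjoint_window,
    #         utils.get_rolling_token_windows(
    #             token_list=self.tok_encode(string),
    #             prefix_token=self.eot_token_id,
    #             max_seq_len=self.max_length,
    #             context_len=1,
    #         ),
    #     )
    # Each resulting (context, continuation) window is scored separately and the
    # per-window log-likelihoods are summed to give the document's total.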

    # TODO: Add an optional max length
    @abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of strings continuation
            continuation: str
                The generated continuation.
        """
        pass

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config=None):
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)
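    # Example (hypothetical values): an arg string such as
    #     "pretrained=gpt2,batch_size=8"
    # is split into key=value keyword arguments by utils.simple_parse_args_string,
    # merged with any non-None entries from `additional_config`, and passed to the
    # subclass constructor.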

    def set_cache_hook(self, cache_hook):
        self.cache_hook = cache_hook


class BaseLM(LM):
    @property
    @abstractmethod
    def eot_token_id(self):
        pass

    @property
    @abstractmethod
    def max_length(self):
        pass

    @property
    @abstractmethod
    def max_gen_toks(self):
        pass

    @property
    @abstractmethod
    def batch_size(self):
        pass

    @property
    @abstractmethod
    def device(self):
        pass

    @abstractmethod
    def tok_encode(self, string: str):
        pass

    @abstractmethod
    def tok_decode(self, tokens: Iterable[int]):
        pass

    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id):
        pass

    @abstractmethod
    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        pass

    # subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
    # TODO: enforce this somehow

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context)

            continuation_enc = self.tok_encode(continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        # TODO: Implement caching once we've confirmed the perplexity implementation
        # TODO: automatic batch size detection for vectorization

        loglikelihoods = []
        for (string,) in tqdm(requests):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows, disable_tqdm=True
            )

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        # TODO: automatic (variable) batch size detection for vectorization
        reord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
            tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size
        ):
            inps = []
            cont_toks_list = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # gpt2    \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                    dtype=torch.long,
                ).to(self.device)
                (inplen,) = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = (
                    padding_length if padding_length is not None else inplen
                )

                # pad length from seq to padding_length
                inp = torch.cat(
                    [
                        inp,  # [seq]
                        torch.zeros(padding_length - inplen, dtype=torch.long).to(
                            inp.device
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                )

                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(cont)
                inplens.append(inplen)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(
                self._model_call(batched_inps), dim=-1
            ).cpu()  # [batch, padding_length, vocab]

            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
                chunk, multi_logits, inps, inplens, cont_toks_list
            ):

                # Slice to original seq length
                contlen = len(cont_toks)
                logits = logits[inplen - contlen : inplen].unsqueeze(
                    0
                )  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
                    0
                )  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                    -1
                )  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return reord.get_original(res)

    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles untils that are
        #       multiple tokens or that span multiple tokens correctly

        # TODO: extract to TokenizedLM?
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        reord = utils.Reorderer(requests, _collate)

        for context, request_args in tqdm(reord.get_reordered()):
            stopping_criteria = request_args["stopping_criteria"]
            max_generation_length = request_args["max_generation_length"]

            assert isinstance(stopping_criteria, str) or stopping_criteria is None
            assert (
                isinstance(max_generation_length, int) or max_generation_length is None
            )

            until = [stopping_criteria]
            primary_until = self.tok_encode(until[0])
            context_enc = torch.tensor(
                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
            ).to(self.device)

            if max_generation_length is None:
                max_length = context_enc.shape[1] + self.max_gen_toks
            else:
                max_length = min(
                    max_generation_length, context_enc.shape[1] + self.max_gen_toks
                )
            cont = self._model_generate(
                context_enc,
                max_length,
                torch.tensor(primary_until),
            )

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])

            for term in until:
                s = s.split(term)[0]

            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)

            res.append(s)

        return reord.get_original(res)


class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary e.g.
        {"question": ..., "answer": ...} or
        (question, answer)
    """

    # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
    # or a path to a custom `datasets` loading script.
    DATASET_PATH: str = None

    # The name of a subset within `DATASET_PATH`.
    DATASET_NAME: str = None

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None):
        """
        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.download(data_dir, cache_dir, download_mode)
        self._training_docs = None
        self._fewshot_docs = None

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Downloads and returns the task dataset.
        Override this method to download the dataset from a custom API.

        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
        )

    @abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def test_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def _process_doc(self, doc):
        """
        Override this to process (detokenize, strip, replace, etc.) individual
        documents. This can be used in a map over documents of a data split.
        E.g. `map(self._process_doc, self.dataset["validation"])`

        :return: dict
            The processed version of the specified `doc`.
        """
        return doc

    def fewshot_examples(self, k, rnd):
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    @abstractmethod
    def doc_to_text(self, doc):
        pass

    @abstractmethod
    def doc_to_target(self, doc):
        pass

    @abstractmethod
    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        pass

    @abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluate them, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document
        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass

    @abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        pass

    def fewshot_description(self):
        import warnings

        warnings.warn(
            "`fewshot_description` will be removed in future versions. Pass "
            "any custom descriptions to the `evaluate` function instead.",
            DeprecationWarning,
        )
        return ""

    @utils.positional_deprecated
    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param provide_description: bool
            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
        :param rnd: random.Random
            The pseudo-random number generator used to randomly sample examples.
            WARNING: This is currently a required arg although it's optionalized with a default `None`.
        :param description: str
            The task's description that will be prepended to the fewshot examples.
        :returns: str
            The fewshot context.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        description = description + "\n\n" if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs()
                        if self.has_validation_docs()
                        else self.test_docs()
                    )

                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = (
                "\n\n".join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + "\n\n"
            )

        example = self.doc_to_text(doc)
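        # Shape of the returned string (task content is hypothetical), for a
        # non-empty description and num_fewshot=2:
        #     "<description>\n\n"
        #     "<doc_to_text(ex1)><doc_to_target(ex1)>\n\n"
        #     "<doc_to_text(ex2)><doc_to_target(ex2)>\n\n"
        #     "<doc_to_text(doc)>"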
        return description + labeled_examples + example


class PromptSourceTask(Task):
    """These are the metrics from promptsource that we have
    added default behavior for. If you want to add default behavior for a new metric,
    update the functions below. If you want to use one of the following metrics,
    *and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`.

    WARNING: ROUGE is WIP.
    """

    CONFIGURED_PS_METRICS = set(["Accuracy", "BLEU", "ROUGE"])

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None):
        super().__init__(data_dir, cache_dir, download_mode)
        self.prompt = prompt

    def stopping_criteria(self) -> Optional[str]:
        """Denote where the generation should end.

        For example, for coqa, this is '\nQ:' and for drop '.'.

        By default, it is None, meaning generation continues up to the max length or EOT, whichever comes first.
        """
        return None

    def max_generation_length(self) -> Optional[int]:
        """Denote the max length of the generation if it is obvious from the task."""
        return None

    def invalid_doc_for_prompt(self, doc) -> bool:
        """Some prompts may not work for some documents."""
        if (
            # generate_paraphrase for mrpc
            # This generation prompt assumes a positive example. We filter out the negative examples.
            # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L7
            # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L88
            (
                self.prompt.id == "3b88d2c4-0aeb-4c6d-9ccc-653a388250a5"
                or self.prompt.id == "d830d7a5-abc0-4275-ac62-974e0088876f"
            )
            and doc["label"] == 0
        ):
            return True
        return False

    def doc_to_target(self, doc) -> str:
        """NOTE: In the future, this may return Union[str, List[str]]."""
        _, target = self.prompt.apply(doc)
        return f" {target}"

    def doc_to_text(self, doc) -> str:
        text, _ = self.prompt.apply(doc)
        return text

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        _requests = []
        answer_choices_list = self.prompt.get_answer_choices_list(doc)

        if answer_choices_list:
            # If answer_choices_list, then this is a ranked choice prompt.
            for answer_choice in answer_choices_list:
                ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}")
                _requests.append(ll_answer_choice)
        else:
            # If not, then this is a generation prompt.
            # NOTE: In the future, target will be a list of strings.
            request_args = {
                "stopping_criteria": self.stopping_criteria(),
                "max_generation_length": self.max_generation_length(),
            }
            cont_request = rf.greedy_until(ctx, request_args)
            _requests.append(cont_request)

        return _requests

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluate them, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        target = self.doc_to_target(doc).strip()
        answer_choices_list = self.prompt.get_answer_choices_list(doc)
        if answer_choices_list:
            # If answer_choices_list, then this is a ranked choice prompt.
            # NOTE: In the future, target will be a list of strings.
            # For now, we can assume there will be only 1 target, but it's possible
            # that this is not the case, so we should check for that.

            pred = answer_choices_list[np.argmax(results)]
            out = {}

            for metric in self.prompt.metadata.metrics:
                assert (
                    metric in self.CONFIGURED_PS_METRICS
                ), "Unexpected metric. Add it, or use a task-specific solution."
                if metric == "Accuracy":
                    out["acc"] = pred == target
            # TODO: Add metrics here.
            return out
        else:
            # If not, then this is a generation prompt.
            # NOTE: In the future, target will be a list of strings.
            pred = results[0].strip()
            out = {}

            for metric in self.prompt.metadata.metrics:
                assert (
                    metric in self.CONFIGURED_PS_METRICS
                ), "Unexpected metric. Add it, or use a task-specific solution."
                if metric == "BLEU":
                    out["bleu"] = (target, pred)
                if metric == "ROUGE":
                    print("WARNING: Skipping Rouge.")

            return out

    def higher_is_better(self):
        out = {}
        for metric in self.prompt.metadata.metrics:
            assert (
                metric in self.CONFIGURED_PS_METRICS
            ), "Unexpected metric. Add it, or use a task-specific solution."
            if metric == "Accuracy":
                out["acc"] = True
            if metric == "BLEU":
                out["bleu"] = True
            if metric == "ROUGE":
                print("WARNING: Skipping Rouge.")
        return out

    def aggregation(self):
        out = {}
        for metric in self.prompt.metadata.metrics:
            assert (
                metric in self.CONFIGURED_PS_METRICS
            ), "Unexpected metric. Add it, or use a task-specific solution."
            if metric == "Accuracy":
                out["acc"] = mean
            if metric == "BLEU":
                out["bleu"] = metrics.bleu
            if metric == "ROUGE":
                print("WARNING: Skipping Rouge.")
        return out


class MultipleChoiceTask(Task):
    def doc_to_target(self, doc):
        return " " + doc["choices"][doc["gold"]]

    def construct_requests(self, doc, ctx):
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
        ]

        return lls

    def process_results(self, doc, results):
        gold = doc["gold"]

        acc = 1.0 if np.argmax(results) == gold else 0.0
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
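        # Worked example (hypothetical numbers): with doc["choices"] = ["yes",
        # "absolutely"], gold = 0, and results = [-4.0, -9.0], argmax picks
        # index 0, so acc = 1.0; after dividing by choice length (-4.0/3 vs.
        # -9.0/10) the argmax moves to index 1, so acc_norm = 0.0.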

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }

    def higher_is_better(self):
        return {
            "acc": True,
            "acc_norm": True,
        }

    def aggregation(self):
        return {
            "acc": mean,
            "acc_norm": mean,
        }


class PerplexityTask(Task, abc.ABC):
    def has_training_docs(self):
        return False

    def fewshot_examples(self, k, rnd):
        assert k == 0
        return []

    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        assert (
            num_fewshot == 0
        ), "The number of fewshot examples must be 0 for perplexity tasks."
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`."
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        return ""

    def higher_is_better(self):
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_text(self, doc):
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc, ctx):
        assert not ctx
        req = rf.loglikelihood_rolling(self.doc_to_target(doc))
        return req

    def process_results(self, doc, results):
        (loglikelihood,) = results
        words = self.count_words(doc)
        bytes_ = self.count_bytes(doc)
        return {
            "word_perplexity": (loglikelihood, words),
            "byte_perplexity": (loglikelihood, bytes_),
            "bits_per_byte": (loglikelihood, bytes_),
        }

    def aggregation(self):
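        # Note (a sketch based on the lm_eval.metrics helpers imported above, not
        # a spec): the (loglikelihood, count) pairs emitted by process_results are
        # aggregated roughly as exp(-sum(ll) / sum(count)) for the perplexities and
        # -sum(ll) / (sum(bytes) * ln 2) for bits_per_byte.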
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": bits_per_byte,
        }

    @classmethod
    def count_bytes(cls, doc):
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc):
        """Downstream tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))


def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()


class CacheHook:
    def __init__(self, cachinglm):
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res):
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db):
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())
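        # Hypothetical usage sketch (path is illustrative):
        #     lm = CachingLM(base_lm, "lm_cache/model.db")
        #     lm.loglikelihood(requests)  # cache hits come from the SqliteDict,
        #                                 # misses are forwarded to base_lm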

    def __getattr__(self, attr):
        def fn(requests):
            res = []
            remaining_reqs = []

            # figure out which ones are cached and which ones are new
            for req in requests:
                hsh = hash_args(attr, req)
                if hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)

            # actually run the LM on the requests that do not have cached results
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        return CacheHook(self)


REQUEST_RETURN_LENGTHS = {
    "loglikelihood": 2,
    "greedy_until": None,
    "loglikelihood_rolling": None,
}


class Request:
    def __init__(self, request_type, args, index=None):
        if request_type not in REQUEST_RETURN_LENGTHS.keys():
            raise NotImplementedError(
                "The request type {} is not implemented!".format(request_type)
            )

        self.request_type = request_type
        self.args = args
        self.index = index

    def __iter__(self):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
            yield Request(self.request_type, self.args, i)

    def __getitem__(self, i):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        return Request(self.request_type, self.args, i)

    def __eq__(self, other):
        return (
            self.request_type == other.request_type
            and self.args == other.args
            and self.index == other.index
        )

    def __repr__(self):
        return f"Req_{self.request_type}{self.args}[{self.index}]\n"


class RequestFactory:
    def __getattr__(self, attr):
        def fn(*args):
            return Request(attr, args)

        return fn


rf = RequestFactory()
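# Usage sketch (values hypothetical): tasks build deferred requests through the
# factory above, e.g.
#     ll_req = rf.loglikelihood(ctx, " answer")[0]
# which yields a Request whose `index` later selects the logprob entry of the
# (logprob, isgreedy) pair returned by the LM for that request.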