import abc
from typing import Iterable, Optional

import promptsource
import numpy as np
import random
import re
import os
import json
import hashlib
import datasets
from sqlitedict import SqliteDict
from tqdm import tqdm
import torch
import torch.nn.functional as F

from lm_eval import metrics
from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
from lm_eval import utils
from abc import abstractmethod


class LM(abc.ABC):
    def __init__(self):
        self.cache_hook = CacheHook(None)

    @abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass
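
        # Illustrative example (hypothetical values): a request such as
        #   [("The quick brown fox", " jumps over the lazy dog")]
        # might return [(-12.3, True)]: the summed logprob of the continuation
        # tokens, and whether each continuation token was the greedy argmax choice.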

    @abstractmethod
    def loglikelihood_rolling(self, requests):
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list
            A list of strings
            string: str
                String for which we are computing per-token loglikelihood
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass

    # TODO: Add an optional max length
    @abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of strings continuation
            continuation: str
                The generated continuation.
        """
        pass
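
        # Illustrative example (hypothetical values): a request such as
        #   [("Q: What is 2+2?\nA:", ["\n"])]
        # might return [" 4"]; BaseLM.greedy_until trims the stopping sequence
        # from the returned continuation.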

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config=None):
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)
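
    # Illustrative usage (hypothetical subclass and arguments):
    #   MyLM.create_from_arg_string("pretrained=gpt2,batch_size=8")
    # parses the string with utils.simple_parse_args_string and forwards the
    # resulting keyword arguments (plus any additional_config) to the constructor.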

    def set_cache_hook(self, cache_hook):
        self.cache_hook = cache_hook


class BaseLM(LM):
    @property
    @abstractmethod
    def eot_token(self):
        pass

    @property
    @abstractmethod
    def eot_token_id(self):
        pass

    @property
    @abstractmethod
    def max_length(self):
        pass

    @property
    @abstractmethod
    def max_gen_toks(self):
        pass

    @property
    @abstractmethod
    def batch_size(self):
        pass

    @property
    @abstractmethod
    def device(self):
        pass

    @abstractmethod
    def tok_encode(self, string: str):
        pass

    @abstractmethod
    def tok_decode(self, tokens: Iterable[int]):
        pass

    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id):
        pass

    @abstractmethod
    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        pass
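
        # Illustrative shape check (hypothetical numbers): for inps of shape
        # [2, 5] and a 50257-token vocab, the returned logits have shape
        # [2, 5, 50257].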

    # subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
    # TODO: enforce this somehow

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context)

            continuation_enc = self.tok_encode(continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        # TODO: Implement caching once we've confirmed the perplexity implementation
        # TODO: automatic batch size detection for vectorization

        loglikelihoods = []
        for (string,) in tqdm(requests):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )
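
            # Illustrative sketch of the windows produced above, assuming
            # max_length=4 and tokens [0..9] (cf. the docstring example in
            # LM.loglikelihood_rolling):
            #   [([eot], [0, 1, 2, 3]), ([3], [4, 5, 6, 7]), ([5, 6, 7], [8, 9])]
            # i.e. disjoint (context, continuation) pairs whose continuations
            # tile the whole token list.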

            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows, disable_tqdm=True
            )

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        # TODO: automatic (variable) batch size detection for vectorization
        reord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
            tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size
        ):
            inps = []
            cont_toks_list = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # gpt2    \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                    dtype=torch.long,
                ).to(self.device)
                (inplen,) = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = (
                    padding_length if padding_length is not None else inplen
                )

                # pad length from seq to padding_length
                inp = torch.cat(
                    [
                        inp,  # [seq]
                        torch.zeros(padding_length - inplen, dtype=torch.long).to(
                            inp.device
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                )

                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(cont)
                inplens.append(inplen)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(
                self._model_call(batched_inps), dim=-1
            ).cpu()  # [batch, padding_length, vocab]

            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
                chunk, multi_logits, inps, inplens, cont_toks_list
            ):

                # Slice to original seq length
                contlen = len(cont_toks)
                logits = logits[inplen - contlen : inplen].unsqueeze(
                    0
                )  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
                    0
                )  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                    -1
                )  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return reord.get_original(res)

    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles untils that are
        #       multiple tokens or that span multiple tokens correctly

        # TODO: extract to TokenizedLM?
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        reord = utils.Reorderer(requests, _collate)

        for context, request_args in tqdm(reord.get_reordered()):
            stopping_criteria = request_args["stopping_criteria"]
            max_generation_length = request_args["max_generation_length"]
            num_fewshot = request_args["num_fewshot"]

            assert isinstance(stopping_criteria, str) or stopping_criteria is None
            assert (
                isinstance(max_generation_length, int) or max_generation_length is None
            )
            assert isinstance(num_fewshot, int) or num_fewshot is None

            if stopping_criteria is None:
                until = [self.eot_token]
            else:
                until = [stopping_criteria]
            primary_until = self.tok_encode(until[0])

            if len(primary_until) == 0:
                primary_until = torch.tensor([self.eot_token_id])

            context_enc = torch.tensor(
                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
            ).to(self.device)

            if max_generation_length is None:
                max_length = context_enc.shape[1] + self.max_gen_toks
            else:
                max_length = min(
                    max_generation_length, context_enc.shape[1] + self.max_gen_toks
                )
            cont = self._model_generate(
                context_enc,
                max_length,
                torch.tensor(primary_until),
                num_fewshot,
            )

            s = self.tok_decode(cont.tolist())

            for term in until:
                s = s.split(term)[0]

            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)

            res.append(s)

        return reord.get_original(res)


class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary e.g.
        {"question": ..., "answer": ...} or
        (question, answer)
    """

    # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
    # or a path to a custom `datasets` loading script.
    DATASET_PATH: str = None

    # The name of a subset within `DATASET_PATH`.
    DATASET_NAME: str = None

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None):
        """
        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.download(data_dir, cache_dir, download_mode)
        self._training_docs = None
        self._fewshot_docs = None

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Downloads and returns the task dataset.
        Override this method to download the dataset from a custom API.

        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
        )

    @abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def test_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def _process_doc(self, doc):
        """
        Override this to process (detokenize, strip, replace, etc.) individual
        documents. This can be used in a map over documents of a data split.
        E.g. `map(self._process_doc, self.dataset["validation"])`

        :return: dict
            The processed version of the specified `doc`.
        """
        return doc

    def fewshot_examples(self, k, rnd):
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    @abstractmethod
    def doc_to_text(self, doc):
        pass

    @abstractmethod
    def doc_to_target(self, doc):
        pass

    @abstractmethod
    def construct_requests(self, doc, ctx, args):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        :param args: dict
            The specifics of the context, including number of few shots.
        """
        pass

    @abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluate, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass

    @abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        pass

    def fewshot_description(self):
        import warnings

        warnings.warn(
            "`fewshot_description` will be removed in future versions. Pass "
            "any custom descriptions to the `evaluate` function instead.",
            DeprecationWarning,
        )
        return ""

    @utils.positional_deprecated
    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param provide_description: bool
            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
        :param rnd: random.Random
            The pseudo-random number generator used to randomly sample examples.
            WARNING: This is currently a required arg although it's optionalized with a default `None`.
        :param description: str
            The task's description that will be prepended to the fewshot examples.
        :returns: str
            The fewshot context.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        description = description + "\n\n" if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs()
                        if self.has_validation_docs()
                        else self.test_docs()
                    )

                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            # See Webson & Pavlick (2022) https://arxiv.org/pdf/2109.01247.pdf
            # for justification of this separator.
            example_separator = "\n###\n"

            labeled_examples = (
                example_separator.join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + example_separator
            )

        example = self.doc_to_text(doc)
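
        # Illustrative layout of the returned context (hypothetical contents):
        #   "<description>\n\n<shot 1>\n###\n<shot 2>\n###\n<text of doc>"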
        return description + labeled_examples + example


class PromptSourceTask(Task):
    """These are the metrics from promptsource that we have
    added default behavior for. If you want to add default behavior for a new metric,
    update the functions below. If you want to use one of the following metrics,
    *and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`.
    """

    CONFIGURED_RANKED_CHOICE_PS_METRICS = {"Accuracy"}
    CONFIGURED_GENERATION_PS_METRICS = {"BLEU", "ROUGE"}
    SPLIT = None

    def __init__(
        self,
        data_dir=None,
        cache_dir=None,
        download_mode=None,
        prompt=None,
        save_examples=True,
    ):
        super().__init__(data_dir, cache_dir, download_mode)
        self.prompt = prompt
        self.save_examples = save_examples

    def stopping_criteria(self) -> Optional[str]:
        """Denote where the generation should end.

        For example, for coqa, this is '\nQ:' and for drop '.'.

        By default, it's None, meaning generation continues up to the max length or the EOT token, whichever comes first.
        """
        return None

    def max_generation_length(self) -> Optional[int]:
        """Denote the max length of the generation if it is obvious from the task."""
        return None

    def invalid_doc_for_prompt(self, doc) -> bool:
        """Some prompts may not work for some documents."""
        if (
            # generate_paraphrase for mrpc
            # This generation prompt assumes a positive example. We filter out the negative examples.
            # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L7
            # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L88
            (
                self.prompt.id == "3b88d2c4-0aeb-4c6d-9ccc-653a388250a5"
                or self.prompt.id == "d830d7a5-abc0-4275-ac62-974e0088876f"
            )
            and doc["label"] == 0
        ):
            return True
        return False

    def doc_to_target(self, doc) -> str:
        """NOTE: In the future, this may return Union[str, List[str]]."""
        _, target = self.prompt.apply(doc)
        return f" {target}"

    def doc_to_text(self, doc) -> str:
        text, _ = self.prompt.apply(doc)
        return text

    def construct_requests(self, doc, ctx, args):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        :param args: dict
            The specifics of the context, including number of few shots.
        """
        _requests = []
        answer_choices_list = self.prompt.get_answer_choices_list(doc)

        if answer_choices_list:
            # If answer_choices_list, then this is a ranked choice prompt.
            for answer_choice in answer_choices_list:
                ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}")
                _requests.append(ll_answer_choice)
        else:
            # If not, then this is a generation prompt.
            # NOTE: In the future, target will be a list of strings.
            request_args = {
                "stopping_criteria": self.stopping_criteria(),
                "max_generation_length": self.max_generation_length(),
                "num_fewshot": args["num_fewshot"],
            }
            cont_request = rf.greedy_until(ctx, request_args)
            _requests.append(cont_request)

        return _requests

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluate, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        target = self.doc_to_target(doc).strip()
        answer_choices_list = self.prompt.get_answer_choices_list(doc)
        if answer_choices_list:
            # If answer_choices_list, then this is a ranked choice prompt.
            # NOTE: In the future, target will be a list of strings.
            # For now, we can assume there will be only 1 target, but it's possible
            # that this is not the case, so we should check for that.

            pred = answer_choices_list[np.argmax(results)]
            out = {}

            for metric in self.prompt.metadata.metrics:
                assert (
                    metric in self.CONFIGURED_RANKED_CHOICE_PS_METRICS
                ), "Unexpected metric. Add it, or use a task-specific solution."
                if metric == "Accuracy":
                    out["acc"] = pred == target
            # TODO: Add metrics here.
        else:
            # If not, then this is a generation prompt.
            # NOTE: In the future, target will be a list of strings.
            pred = results[0].strip()
            out = {}
            for metric in self.prompt.metadata.metrics:
                assert (
                    metric in self.CONFIGURED_GENERATION_PS_METRICS
                ), "Unexpected metric. Add it, or use a task-specific solution."
                if metric == "BLEU":
                    out["bleu"] = (target, pred)
                elif metric == "ROUGE":
                    # TODO: This computes all rouge sub-metrics. Find a generic
                    # way to handle user specified rouge sub-metrics to avoid extra
                    # compute.
                    rouge_scores = metrics.rouge(target, pred)
                    # Flatten rouge score dict.
                    rouge_scores = utils.flatten(rouge_scores)
                    # Merge all the rouge-type scores into the `out` dict.
                    out = {**out, **rouge_scores}

        # TODO: Wrap process_results s.t. overriding implementations do not
        # bypass the example saving below.
        if self.save_examples:
            example = {
                "pred": pred,
                "target": target,
                "answer_choices_list": answer_choices_list,
            }
            return out, example
        return out

    def higher_is_better(self):
        out = {}
        for metric in self.prompt.metadata.metrics:
            if metric == "Accuracy":
                out["acc"] = True
            if metric == "BLEU":
                out["bleu"] = True
            if metric == "ROUGE":
                # TODO: Find a generic way to handle user specified rouge metrics.
                out["rouge1_precision"] = True
                out["rouge1_recall"] = True
                out["rouge1_fmeasure"] = True

                out["rouge2_precision"] = True
                out["rouge2_recall"] = True
                out["rouge2_fmeasure"] = True

                out["rougeL_precision"] = True
                out["rougeL_recall"] = True
                out["rougeL_fmeasure"] = True

                out["rougeLsum_precision"] = True
                out["rougeLsum_recall"] = True
                out["rougeLsum_fmeasure"] = True
        return out

    def aggregation(self):
        out = {}
        for metric in self.prompt.metadata.metrics:
            if metric == "Accuracy":
                out["acc"] = mean
            if metric == "BLEU":
                out["bleu"] = metrics.bleu
            if metric == "ROUGE":
                # TODO: Find a generic way to handle user specified rouge metrics.
                out["rouge1_precision"] = mean
                out["rouge1_recall"] = mean
                out["rouge1_fmeasure"] = mean

                out["rouge2_precision"] = mean
                out["rouge2_recall"] = mean
                out["rouge2_fmeasure"] = mean

                out["rougeL_precision"] = mean
                out["rougeL_recall"] = mean
                out["rougeL_fmeasure"] = mean

                out["rougeLsum_precision"] = mean
                out["rougeLsum_recall"] = mean
                out["rougeLsum_fmeasure"] = mean
        return out

    def fewshot_examples(self, k, rnd):
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())
        return self._get_fewshot_examples(self._training_docs, k, rnd)

    def _get_fewshot_examples(self, docs, k, rnd):
        fewshot_idx = rnd.sample(list(np.arange(len(docs))), k)
        return [docs[idx] for idx in fewshot_idx], [int(idx) for idx in fewshot_idx]

    @utils.positional_deprecated
    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param provide_description: bool
            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
        :param rnd: random.Random
            The pseudo-random number generator used to randomly sample examples.
            WARNING: This is currently a required arg although it's optionalized with a default `None`.
        :param description: str
            The task's description that will be prepended to the fewshot examples.
        :returns: str
            The fewshot context.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        description = description + "\n\n" if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
            fewshotex, fewshotidx, self.fewshotsource = [], [], None
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex, fewshotidx = self.fewshot_examples(k=num_fewshot, rnd=rnd)
                self.fewshotsource = "train"
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs()
                        if self.has_validation_docs()
                        else self.test_docs()
                    )
                    if self.has_validation_docs():
                        self.fewshotsource = "val"
                    elif self.test_docs():
                        self.fewshotsource = "test"

                fewshotex, fewshotidx = self._get_fewshot_examples(
                    self._fewshot_docs, k=num_fewshot + 1, rnd=rnd
                )
                fewshotex, fewshotidx = zip(*[
                    (shot, idx)
                    for shot, idx in zip(fewshotex, fewshotidx)
                    if shot != doc
                ])
                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex, fewshotidx = (
                    fewshotex[:num_fewshot],
                    fewshotidx[:num_fewshot],
                )
            # See Webson & Pavlick (2022) https://arxiv.org/pdf/2109.01247.pdf
            # for justification of this separator.
            example_separator = "\n###\n"

            labeled_examples = (
                example_separator.join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + example_separator
            )

        example = self.doc_to_text(doc)
        ctx = description + labeled_examples + example
        return (
            ctx,
            {
                "fewshot_idx": fewshotidx,
                "fewshot_source": self.fewshotsource,
                "fewshot_num": num_fewshot,
                "ctx": ctx,
            },
        )

    def get_logging_info(self):
        return {
            "fixed_answer_choice_list": self.prompt.get_fixed_answer_choices_list(),
            "dataset_path": self.DATASET_PATH,
            "dataset_name": self.DATASET_NAME,
            "subset": self.SPLIT,
            "prompt_name": self.prompt.get_name(),
            "prompt_id": self.prompt.get_id(),
            "prompt_jinja": self.prompt.jinja,
            "prompt_original_task": self.prompt.metadata.original_task,
            # Placeholder for comment in post-processing.
            "comment": "",
        }


class MultipleChoiceTask(Task):
    def doc_to_target(self, doc):
        return " " + doc["choices"][doc["gold"]]

    def construct_requests(self, doc, ctx):
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
        ]

        return lls

    def process_results(self, doc, results):
        gold = doc["gold"]

        acc = 1.0 if np.argmax(results) == gold else 0.0
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
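        # Length-normalized accuracy: divide each choice's log-likelihood by the
        # choice's character length so that longer choices are not penalized.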
        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }

    def higher_is_better(self):
        return {
            "acc": True,
            "acc_norm": True,
        }

    def aggregation(self):
        return {
            "acc": mean,
            "acc_norm": mean,
        }


class PerplexityTask(Task, abc.ABC):
    def has_training_docs(self):
        return False

    def fewshot_examples(self, k, rnd):
        assert k == 0
        return []

    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        assert (
            num_fewshot == 0
        ), "The number of fewshot examples must be 0 for perplexity tasks."
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`."
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        return ""

    def higher_is_better(self):
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_text(self, doc):
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc, ctx):
        assert not ctx
        req = rf.loglikelihood_rolling(self.doc_to_target(doc))
        return req

    def process_results(self, doc, results):
        (loglikelihood,) = results
        words = self.count_words(doc)
        bytes_ = self.count_bytes(doc)
        return {
            "word_perplexity": (loglikelihood, words),
            "byte_perplexity": (loglikelihood, bytes_),
            "bits_per_byte": (loglikelihood, bytes_),
        }

    def aggregation(self):
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": bits_per_byte,
        }

    @classmethod
    def count_bytes(cls, doc):
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc):
        """Downstream tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))


def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()


class CacheHook:
    def __init__(self, cachinglm):
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res):
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db):
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())
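
        # Illustrative usage (hypothetical path):
        #   lm = CachingLM(inner_lm, "lm_cache/model.db")
        #   lm.loglikelihood(reqs)  # computed once, then served from the cache db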

    def __getattr__(self, attr):
        def fn(requests):
            res = []
            remaining_reqs = []

            # figure out which ones are cached and which ones are new
            for req in requests:
                hsh = hash_args(attr, req)
                if hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)

            # actually run the LM on the requests that do not have cached results
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        return CacheHook(self)


REQUEST_RETURN_LENGTHS = {
    "loglikelihood": 2,
    "greedy_until": None,
    "loglikelihood_rolling": None,
}
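
# A request type with a known return length (e.g. loglikelihood returns a
# (logprob, is_greedy) pair) can be indexed or iterated to address each
# returned value separately; see Request.__getitem__ and Request.__iter__.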


class Request:
    def __init__(self, request_type, args, index=None):
        if request_type not in REQUEST_RETURN_LENGTHS.keys():
            raise NotImplementedError(
                "The request type {} is not implemented!".format(request_type)
            )

        self.request_type = request_type
        self.args = args
        self.index = index

    def __iter__(self):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
            yield Request(self.request_type, self.args, i)

    def __getitem__(self, i):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        return Request(self.request_type, self.args, i)

    def __eq__(self, other):
        return (
            self.request_type == other.request_type
            and self.args == other.args
            and self.index == other.index
        )

    def __repr__(self):
        return f"Req_{self.request_type}{self.args}[{self.index}]\n"


class RequestFactory:
    def __getattr__(self, attr):
        def fn(*args):
            return Request(attr, args)

        return fn


rf = RequestFactory()
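
# Illustrative usage: rf.loglikelihood(ctx, " choice") builds
# Request("loglikelihood", (ctx, " choice")), which the evaluator later
# dispatches to the LM method of the same name.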