base.py 29.4 KB
Newer Older
Leo Gao's avatar
Leo Gao committed
1
import abc
2
from typing import Iterable
thefazzer's avatar
thefazzer committed
3
import numpy as np
4
import random
Leo Gao's avatar
Leo Gao committed
5
import re
6
7
8
import os
import json
import hashlib
Jonathan Tow's avatar
Jonathan Tow committed
9
import datasets
10
from sqlitedict import SqliteDict
11
from tqdm import tqdm
12
import torch
Leo Gao's avatar
Leo Gao committed
13
import torch.nn.functional as F
&'s avatar
& committed
14

15
from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
16
from lm_eval import utils
17
from abc import abstractmethod
Jason Phang's avatar
gpt3  
Jason Phang committed
18

Jason Phang's avatar
Jason Phang committed
19

Leo Gao's avatar
Leo Gao committed
20
class LM(abc.ABC):
    """Abstract interface for a language model under evaluation.

    Concrete subclasses must implement the three request primitives
    (`loglikelihood`, `loglikelihood_rolling`, `greedy_until`); the rest of
    the harness is built on top of these calls.
    """

    def __init__(self):
        # Start with a no-op hook; set_cache_hook() installs a real one.
        # (CacheHook is defined elsewhere in this module.)
        self.cache_hook = CacheHook(None)

    @abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass

    @abstractmethod
    def loglikelihood_rolling(self, requests):
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list
            A list of strings
            string: str
                String for which we are computing per-token loglikelihood
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass

    # TODO: Add an optional max length
    @abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of strings continuation
            continuation: str
                The generated continuation.
        """
        pass

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config=None):
        """Instantiate the class from a parseable argument string.

        :param arg_string: str
            Parsed by utils.simple_parse_args_string into constructor kwargs.
        :param additional_config: dict | None
            Extra constructor kwargs; entries whose value is None are dropped
            so they cannot override parsed args with an empty value.
        """
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)

    def set_cache_hook(self, cache_hook):
        # Install the hook used for partial caching of request results.
        self.cache_hook = cache_hook

Leo Gao's avatar
Leo Gao committed
119

120
class BaseLM(LM):
    """Shared implementation of the LM interface for token-level models.

    Subclasses supply tokenizer/model primitives (the abstract properties,
    `tok_encode`/`tok_decode`, `_model_call`, `_model_generate`); this class
    implements batched loglikelihood scoring, rolling perplexity windows,
    and greedy generation on top of them.
    """

    @property
    @abstractmethod
    def eot_token_id(self):
        # Token id used as the "end of text" prefix when the context is empty.
        pass

    @property
    @abstractmethod
    def max_length(self):
        # Maximum number of tokens the model accepts as input.
        pass

    @property
    @abstractmethod
    def max_gen_toks(self):
        # Maximum number of tokens to generate in greedy_until.
        pass

    @property
    @abstractmethod
    def batch_size(self):
        # Number of requests scored per _model_call.
        pass

    @property
    @abstractmethod
    def device(self):
        # Torch device that input tensors are moved to.
        pass

    @abstractmethod
    def tok_encode(self, string: str): pass

    @abstractmethod
    def tok_decode(self, tokens: Iterable[int]): pass

    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id): pass

    @abstractmethod
    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        pass

    # subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
    # TODO: enforce this somehow

    def loglikelihood(self, requests):
        """Tokenize (context, continuation) pairs, then defer to _loglikelihood_tokens."""
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context)

            continuation_enc = self.tok_encode(continuation)

            # Keep the raw string pair as the cache key for partial caching.
            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        """Score each document in rolling, maximally-overlapping context windows.

        See LM.loglikelihood_rolling for the windowing contract.
        """
        # TODO: Implement caching once we've confirmed the perplexity implementation
        # TODO: automatic batch size detection for vectorization

        loglikelihoods = []
        for string, in tqdm(requests):
            rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
                token_list=self.tok_encode(string),
                prefix_token=self.eot_token_id,
                max_seq_len=self.max_length,
                context_len=1,
            )))

            # Prepend None as the cache key so _loglikelihood_tokens skips partial caching.
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that
            string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True)

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]

            # Total document loglikelihood is the sum over all windows.
            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        """Score pre-tokenized (cache_key, context_enc, continuation_enc) requests.

        :param requests: list of (cache_key, context_enc, continuation_enc);
            a cache_key of None disables partial caching for that request.
        :param disable_tqdm: bool
            Suppress the progress bar (used by loglikelihood_rolling).
        :return: list of (logprob, is_greedy) in the original request order.
        """
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        # TODO: automatic (variable) batch size detection for vectorization
        reord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
            inps = []
            cont_toks_list = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # gpt2    \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length+1):][:-1],
                    dtype=torch.long
                ).to(self.device)
                inplen, = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = padding_length if padding_length is not None else inplen

                # pad length from seq to padding_length
                inp = torch.cat([
                    inp,  # [seq]
                    torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device)  # [padding_length - seq]
                ], dim=0)

                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(cont)
                inplens.append(inplen)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(self._model_call(batched_inps), dim=-1).cpu()  # [batch, padding_length, vocab]

            for (cache_key, _, _), logits, inp, inplen, cont_toks \
                    in zip(chunk, multi_logits, inps, inplens, cont_toks_list):

                # Slice to original seq length
                contlen = len(cont_toks)
                logits = logits[inplen-contlen:inplen].unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return reord.get_original(res)

    def greedy_until(self, requests):
        """Greedy-decode a continuation for each (context, until) request."""
        # TODO: implement fully general `until` that handles untils that are
        #       multiple tokens or that span multiple tokens correctly

        # TODO: extract to TokenizedLM?
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        reord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(reord.get_reordered()):
            if isinstance(until, str):
                until = [until]

            # NOTE(review): the trailing comma unpacks a single id — this
            # assumes the first stop sequence encodes to exactly one token.
            primary_until, = self.tok_encode(until[0])

            # Truncate the context from the left so context + generation fits in max_length.
            context_enc = torch.tensor([self.tok_encode(context)[self.max_gen_toks - self.max_length:]]).to(self.device)

            cont = self._model_generate(context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until)

            # Decode only the newly generated tokens (drop the echoed context).
            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:])

            # Trim the output at the first occurrence of any stop sequence.
            for term in until:
                s = s.split(term)[0]

            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)

            res.append(s)

        return reord.get_original(res)
Leo Gao's avatar
Leo Gao committed
339

Leo Gao's avatar
Leo Gao committed
340

341
class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary e.g.
        {"question": ..., "answer": ...} or
        (question, answer)
    """

    # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
    # or a path to a custom `datasets` loading script.
    DATASET_PATH: str = None

    # The name of a subset within `DATASET_PATH`.
    DATASET_NAME: str = None

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None):
        """
        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.download(data_dir, cache_dir, download_mode)
        # Lazily-populated caches used by fewshot_examples / fewshot_context.
        self._training_docs = None
        self._fewshot_docs = None

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """ Downloads and returns the task dataset.
        Override this method to download the dataset from a custom API.

        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode
        )

    def should_decontaminate(self):
        """Whether this task supports decontamination against model training set."""
        return False

    @abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self):
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def test_docs(self):
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def _process_doc(self, doc):
        """
        Override this to process (detokenize, strip, replace, etc.) individual
        documents. This can be used in a map over documents of a data split.
        E.g. `map(self._process_doc, self.dataset["validation"])`

        :return: dict
            The processed version of the specified `doc`.
        """
        return doc

    def fewshot_examples(self, k, rnd):
        """Sample `k` training docs with the caller-provided `random.Random`."""
        # Materialize the training split once and reuse it across calls.
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    def doc_to_decontamination_query(self, doc):
        # NOTE(review): `assert(False)` is stripped under `python -O`, making this a
        # silent no-op; raising NotImplementedError would be safer — confirm before changing.
        print("Override doc_to_decontamination_query with document specific decontamination query.")
        assert(False)

    @abstractmethod
    def doc_to_text(self, doc):
        """Convert a doc into its prompt text (the model input)."""
        pass

    @abstractmethod
    def doc_to_target(self, doc):
        """Convert a doc into its target/answer text."""
        pass

    @abstractmethod
    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of 
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural 
            language description, as well as the few shot examples, and the question
            part of the document for `doc`. 
        """
        pass

    @abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a 
        dict where keys are the names of submetrics and values are the values of 
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are 
            functions that aggregate a list of metric scores
        """
        pass

    @abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are 
            whether a higher value of the submetric is better
        """
        pass

    def fewshot_description(self):
        """Deprecated: returns an empty description and warns."""
        import warnings
        warnings.warn(
            "`fewshot_description` will be removed in futures versions. Pass "
            "any custom descriptions to the `evaluate` function instead.",
            DeprecationWarning)
        return ""

    @utils.positional_deprecated
    def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
        """ Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param provide_description: bool
            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
        :param rnd: random.Random
            The pseudo-random number generator used to randomly sample examples.
            WARNING: This is currently a required arg although it's optionalized with a default `None`.
        :param description: str
            The task's description that will be prepended to the fewshot examples.
        :returns: str
            The fewshot context.
        """
        assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`"
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        # NOTE(review): truthy values trip the assert above, so this warning only
        # fires for falsy non-None values (False, 0, "").
        if provide_description is not None:
            # nudge people to not specify it at all
            print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")

        description = description + "\n\n" if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs() if self.has_validation_docs() else self.test_docs()
                    )

                # Oversample by one so we can drop the evaluated doc if drawn.
                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = "\n\n".join(
                [self.doc_to_text(doc) + self.doc_to_target(doc) for doc in fewshotex]
            ) + "\n\n"

        example = self.doc_to_text(doc)
        return description + labeled_examples + example


Jon Tow's avatar
Jon Tow committed
596
597
class MultipleChoiceTask(Task):
    """Task variant where each doc carries candidate `choices` and a `gold` index.

    Scoring issues one loglikelihood request per choice and reports raw and
    length-normalized accuracy.
    """

    def doc_to_target(self, doc):
        # The target is the gold choice, space-prefixed to respect word boundaries.
        return " " + doc['choices'][doc['gold']]

    def construct_requests(self, doc, ctx):
        # One loglikelihood request per candidate choice, each scored against
        # the same context; keep only the logprob slot of each request.
        choice_requests = []
        for option in doc['choices']:
            choice_requests.append(rf.loglikelihood(ctx, " {}".format(option))[0])
        return choice_requests

    def process_results(self, doc, results):
        """Score one doc: 1/0 accuracy for raw and length-normalized argmax."""
        gold_index = doc["gold"]

        # Raw accuracy: did the gold choice get the highest loglikelihood?
        acc = 1. if np.argmax(results) == gold_index else 0.

        # Length-normalized accuracy: longer choices accumulate more (negative)
        # logprob mass, so divide by character length before the argmax.
        char_lens = np.array([float(len(choice)) for choice in doc["choices"]])
        acc_norm = 1. if np.argmax(results / char_lens) == gold_index else 0.

        return {"acc": acc, "acc_norm": acc_norm}

    def higher_is_better(self):
        # Both accuracy variants improve as they increase.
        return {metric: True for metric in ("acc", "acc_norm")}

    def aggregation(self):
        # Both metrics are averaged (mean) across documents.
        return {metric: mean for metric in ("acc", "acc_norm")}


Jason Phang's avatar
Jason Phang committed
634
635
class PerplexityTask(Task, abc.ABC):
    """Zero-shot task scored by rolling loglikelihood over whole documents.

    Reports word-level perplexity, byte-level perplexity, and bits per byte;
    lower is better for all three metrics.
    """

    def has_training_docs(self):
        return False

    def should_decontaminate(self):
        """Whether this task supports decontamination against model training set."""
        return True

    def doc_to_decontamination_query(self, doc):
        # The raw document itself serves as the decontamination query.
        return doc

    def fewshot_examples(self, k, rnd):
        # Perplexity tasks are strictly zero-shot.
        assert k == 0
        return []

    def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
        assert num_fewshot == 0, "The number of fewshot examples must be 0 for perplexity tasks."
        assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`."
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")

        return ""

    def doc_to_text(self, doc):
        # No prompt: the model scores the bare document.
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc, ctx):
        # Perplexity scoring takes no few-shot context at all.
        assert not ctx
        return rf.loglikelihood_rolling(self.doc_to_target(doc))

    def process_results(self, doc, results):
        (loglikelihood,) = results
        n_words = self.count_words(doc)
        n_bytes = self.count_bytes(doc)
        # Each metric is a (loglikelihood, length) pair; the aggregation
        # functions turn these weighted sums into the final scores.
        return {
            "word_perplexity": (loglikelihood, n_words),
            "byte_perplexity": (loglikelihood, n_bytes),
            "bits_per_byte": (loglikelihood, n_bytes),
        }

    def higher_is_better(self):
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def aggregation(self):
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": bits_per_byte,
        }

    @classmethod
    def count_bytes(cls, doc):
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc):
        """ Downstream tasks with custom word boundaries should override this! """
        return len(re.split(r"\s+", doc))
Leo Gao's avatar
Leo Gao committed
707

Jason Phang's avatar
Jason Phang committed
708

Leo Gao's avatar
Leo Gao committed
709
710
711
def hash_args(attr, args):
    """Stable cache key for an LM request: the SHA-256 hex digest of the
    JSON-encoded ``[attr, *args]`` list."""
    payload = json.dumps([attr, *args])
    return hashlib.sha256(payload.encode('utf-8')).hexdigest()
Leo Gao's avatar
Leo Gao committed
712
713


Leo Gao's avatar
Leo Gao committed
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
class CacheHook:
    """Handle that lets an LM write partial results into a CachingLM's
    backing store; built from ``None`` it becomes a silent no-op."""

    def __init__(self, cachinglm):
        # Without a CachingLM there is nothing to write to.
        self.dbdict = None if cachinglm is None else cachinglm.dbdict

    def add_partial(self, attr, req, res):
        """Cache one result `res` for request `req` under method name `attr`."""
        if self.dbdict is None:
            return
        self.dbdict[hash_args(attr, req)] = res


Leo Gao's avatar
Leo Gao committed
729
730
class CachingLM:
    def __init__(self, lm, cache_db):
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        # Make sure the cache file's parent directory exists before opening it.
        parent = os.path.dirname(cache_db)
        if parent:
            os.makedirs(parent, exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # Give the wrapped LM a hook so it can write partial results back here.
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        # Any attribute not found normally is treated as an LM request method;
        # return a wrapper that serves cache hits and delegates only misses.
        def fn(requests):
            results = []
            misses = []

            # Partition into cache hits (filled immediately) and misses
            # (None placeholders, to be filled in order later).
            for request in requests:
                key = hash_args(attr, request)
                if key in self.dbdict:
                    cached = self.dbdict[key]

                    assert cached is not None

                    results.append(cached)
                else:
                    results.append(None)
                    misses.append(request)

            # Run the underlying LM only on the uncached requests.
            fresh = getattr(self.lm, attr)(misses)

            # Fill each None slot, in order, with a fresh result; cache as we go.
            slot = 0
            for request, value in zip(misses, fresh):
                while results[slot] is not None:
                    slot += 1

                results[slot] = value

                # caching
                self.dbdict[hash_args(attr, request)] = value
            self.dbdict.commit()

            return results
        return fn

    def get_cache_hook(self):
        return CacheHook(self)
Leo Gao's avatar
Leo Gao committed
786

Jason Phang's avatar
Jason Phang committed
787

788
789
790
791
792
793
794
# Number of values each request type yields per call: `loglikelihood`
# returns a (logprob, isgreedy) pair, so Request[i] / iteration is valid
# for it; `None` marks types whose results cannot be split into parts
# (Request.__iter__/__getitem__ raise IndexError for those).
REQUEST_RETURN_LENGTHS = {
    'loglikelihood': 2,
    'greedy_until': None,
    'loglikelihood_rolling': None,
}


795
class Request:
    """A deferred LM call of a given type (see REQUEST_RETURN_LENGTHS).

    :param request_type: str
        One of the keys of REQUEST_RETURN_LENGTHS.
    :param args: tuple
        Arguments to pass to the LM method when the request is executed.
    :param index: int | None
        Selects one element of a multi-valued result (e.g. the logprob half
        of a `loglikelihood` pair); None means the whole result.
    """

    def __init__(self, request_type, args, index=None):
        # Membership test directly on the dict; `.keys()` was redundant.
        if request_type not in REQUEST_RETURN_LENGTHS:
            raise NotImplementedError('The request type {} is not implemented!'.format(request_type))

        self.request_type = request_type
        self.args = args
        self.index = index

    def __iter__(self):
        """Yield one indexed sub-request per element of a multi-valued result."""
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError('This request type does not return multiple arguments!')
        for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
            yield Request(self.request_type, self.args, i)

    def __getitem__(self, i):
        """Return a sub-request selecting element ``i`` of the result."""
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError('This request type does not return multiple arguments!')
        return Request(self.request_type, self.args, i)

    def __eq__(self, other):
        # Fix: comparing against a non-Request previously raised
        # AttributeError; returning NotImplemented lets Python fall back to
        # its default (identity-based) comparison instead.
        if not isinstance(other, Request):
            return NotImplemented
        return self.request_type == other.request_type and self.args == other.args and self.index == other.index

    def __repr__(self):
        return f"Req_{self.request_type}{self.args}[{self.index}]\n"
820

Jason Phang's avatar
Jason Phang committed
821

Leo Gao's avatar
Leo Gao committed
822
823
class RequestFactory:
    """Builds Request objects through attribute access:
    ``rf.loglikelihood(ctx, cont)`` yields ``Request('loglikelihood', (ctx, cont))``."""

    def __getattr__(self, attr):
        # The accessed attribute name becomes the request type; the call's
        # positional arguments are captured verbatim.
        def fn(*args):
            return Request(attr, args)
        return fn


rf = RequestFactory()