import abc
from typing import Iterable
import numpy as np
import random
import re
import os
import json
import hashlib
from sqlitedict import SqliteDict
from tqdm import tqdm
import torch
import torch.nn.functional as F

from lm_eval.metrics import mean, weighted_perplexity, weighted_mean
from lm_eval import utils
from abc import abstractmethod


class LM(abc.ABC):
    def __init__(self):
        self.cache_hook = CacheHook(None)

    @abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass
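    # Illustrative usage (a sketch, not part of the interface itself): a task
    # scoring a single answer choice might call
    #   lm.loglikelihood([("Question: Is the sky blue?\nAnswer:", " yes")])
    # and get back a list like [(logprob, is_greedy)].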

    @abstractmethod
    def loglikelihood_rolling(self, requests):
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementaitons
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list
            A list of strings
            string: str
                String for which we are computing per-toke  loglikelihood
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
Leo Gao's avatar
Leo Gao committed
86
87
88
        """
        pass

    # TODO: Add an optional max length
    @abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of strings, continuation
            continuation: str
                The generated continuation.
        """
        pass
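    # Illustrative usage (a sketch): a request such as
    #   lm.greedy_until([("Q: What is 2+2?\nA:", ["\n"])])
    # returns the generated continuation for each context, cut off at the first
    # stop sequence (in BaseLM the stop string itself is not included).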

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config=None):
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)
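    # Illustrative (assuming utils.simple_parse_args_string's comma-separated
    # key=value format): an arg string like "pretrained=gpt2,device=cuda" is
    # parsed into kwargs, merged with `additional_config`, and passed to cls().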

    def set_cache_hook(self, cache_hook):
        self.cache_hook = cache_hook


class BaseLM(LM):

    @property
    @abstractmethod
    def eot_token_id(self):
        pass

    @property
    @abstractmethod
    def max_length(self):
        pass

    @property
    @abstractmethod
    def max_gen_toks(self):
        pass

    @property
    @abstractmethod
    def batch_size(self):
        pass

    @property
    @abstractmethod
    def device(self):
        pass

    @abstractmethod
    def tok_encode(self, string: str): pass

    @abstractmethod
    def tok_decode(self, tokens: Iterable[int]): pass

    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id): pass

    @abstractmethod
    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        pass
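    # A minimal sketch of a concrete `_model_call` (assumes a HuggingFace
    # `transformers` causal LM stored on `self.model`; names are illustrative):
    #
    #   def _model_call(self, inps):
    #       with torch.no_grad():
    #           return self.model(inps)[0]  # [batch, sequence, vocab] logits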

    # subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
    # TODO: enforce this somehow

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context)

            continuation_enc = self.tok_encode(continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        # TODO: Implement caching once we've confirmed the perplexity implementation
        # TODO: automatic batch size detection for vectorization

        loglikelihoods = []
        for string, in tqdm(requests):
            rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
                token_list=self.tok_encode(string),
                prefix_token=self.eot_token_id,
                max_seq_len=self.max_length,
                context_len=1,
            )))

            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that
            string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True)
            
            # discard is_greedy
            string_nll = [x[0] for x in string_nll]
            
            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be overestimates rather than underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return -len(toks), tuple(toks)
        
        # TODO: automatic (variable) batch size detection for vectorization
        reord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
            inps = []
            cont_toks_list = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # gpt2    \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length+1):][:-1],
                    dtype=torch.long
                ).to(self.device)
                inplen, = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = padding_length if padding_length is not None else inplen

                # pad length from seq to padding_length
                inp = torch.cat([
                    inp,  # [seq]
                    torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device)  # [padding_length - seq]
                ], dim=0)

                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(cont)
                inplens.append(inplen)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(self._model_call(batched_inps), dim=-1).cpu()  # [batch, padding_length, vocab]

            for (cache_key, _, _), logits, inp, inplen, cont_toks \
                    in zip(chunk, multi_logits, inps, inplens, cont_toks_list):

                # Slice to original seq length
                contlen = len(cont_toks)
                logits = logits[inplen-contlen:inplen].unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

289
290
291
                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]
292

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return reord.get_original(res)
    
    def greedy_until(self, requests):
        # TODO: implement fully general `until` that correctly handles stop
        #       sequences that are multiple tokens or that span across tokens

        # TODO: extract to TokenizedLM?
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]
        
        reord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(reord.get_reordered()):
            if isinstance(until, str):
                until = [until]

            primary_until, = self.tok_encode(until[0])
            
            context_enc = torch.tensor([self.tok_encode(context)[self.max_gen_toks - self.max_length:]]).to(self.device)

            cont = self._model_generate(context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until)

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:])

            for term in until:
                s = s.split(term)[0]
            
            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)
            
            res.append(s)
        
        return reord.get_original(res)


class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation.

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary, e.g.
        {"question": ..., "answer": ...} or
        (question, answer)
    """
    def __init__(self):
        self.download()
        self._training_docs = None
        self._fewshot_docs = None

    def download(self):
        """Downloads the task dataset if necessary"""
        pass

    @abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def test_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def fewshot_examples(self, k, rnd):
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    @abstractmethod
    def doc_to_text(self, doc):
        pass

    @abstractmethod
    def doc_to_target(self, doc):
        pass

    @abstractmethod
    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        pass

    @abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results, and evaluate them, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass

    @abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are 
            whether a higher value of the submetric is better
        """
        pass

    def fewshot_description(self):
        import warnings
        warnings.warn(
            "`fewshot_description` will be removed in future versions. Pass "
            "any custom descriptions to the `evaluate` function instead.",
            DeprecationWarning)
        return ""

    @utils.positional_deprecated
    def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
        """ Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param provide_description: bool
            Not implemented; this option is deprecated and will be removed in a future version in favor of a different way of providing a description
        :param rnd: random.Random
            The pseudo-random number generator used to randomly sample examples.
            WARNING: If you do not provide a `rnd` arg, a default `random.Random`
            object will be created and seeded with this Task's name attribute, `__name__`.
        :param description: str
            The task's description that will be prepended to the fewshot examples.
        :returns: str
            The fewshot context.
        """
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")

        description = description + "\n\n" if description else ""

        # TODO (jon-tow): Remove this default `rand` behaviour after `provide_description` is removed and remove the respective `rand` arg warning in the docs above.
        if rnd is None:
            rnd = random.Random()
            rnd.seed(self.__name__)

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs() if self.has_validation_docs() else self.test_docs()
                    )

                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = "\n\n".join(
                [self.doc_to_text(doc) + self.doc_to_target(doc) for doc in fewshotex]
            ) + "\n\n"

        example = self.doc_to_text(doc)
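        # Illustrative layout of the returned prompt:
        #   <description>\n\n<shot 1 text+target>\n\n...<shot k text+target>\n\n<doc text>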
        return description + labeled_examples + example


class MultipleChoiceTask(Task, abc.ABC):

    def doc_to_target(self, doc):
        return " " + doc['choices'][doc['gold']]

    def construct_requests(self, doc, ctx):
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0]
            for choice in doc['choices']
        ]
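        # Each rf.loglikelihood(...) call builds a Request whose result is a
        # (logprob, is_greedy) pair; indexing with [0] (see Request.__getitem__)
        # yields a sub-Request referring to just the log-probability.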

        return lls

    def process_results(self, doc, results):
        gold = doc["gold"]

        acc = 1. if np.argmax(results) == gold else 0.
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1. if np.argmax(results / completion_len) == gold else 0.

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }
    
    def higher_is_better(self):
        return {
            "acc": True,
            "acc_norm": True,
        }
    
    def aggregation(self):
        return {
            "acc": mean,
            "acc_norm": mean,
        }


class PerplexityTask(Task, abc.ABC):

    def has_training_docs(self):
        return False

    def fewshot_examples(self, k, rnd):
        assert k == 0
        return []

    def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
        assert num_fewshot == 0
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")

        # TODO (jon-tow): Remove this default `rand` behaviour after `provide_description` is removed and remove the respective `rand` arg warning in the docs above.
        if rnd is None:
            rnd = random.Random()
            rnd.seed(self.__name__)

        return ""

    def higher_is_better(self):
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_text(self, doc):
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc, ctx):
        assert not ctx
        req = rf.loglikelihood_rolling(self.doc_to_target(doc))
        return req

    def process_results(self, doc, results):
        loglikelihood, = results
        words = self.count_words(doc)
        bytes_ = self.count_bytes(doc)
        return {
            "word_perplexity": (loglikelihood, words),
            "byte_perplexity": (loglikelihood, bytes_),
            "bits_per_byte": (-loglikelihood, bytes_)
        }

    def aggregation(self):
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": weighted_mean
        }
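    # Roughly (see lm_eval.metrics for the exact definitions): weighted_perplexity
    # over (loglikelihood, weight) pairs computes exp(-sum(lls) / sum(weights)), so
    # "word_perplexity" is per-word and "byte_perplexity" per-byte perplexity;
    # "bits_per_byte" averages the negative log-likelihood per byte.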

    @classmethod
    def count_bytes(cls, doc):
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc):
        """ Downstream tasks with custom word boundaries should override this! """
        return len(re.split(r"\s+", doc))


def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode('utf-8')).hexdigest()


class CacheHook:
    def __init__(self, cachinglm):
        if cachinglm is None: 
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict
    
    def add_partial(self, attr, req, res):
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db):
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        def fn(requests):
            res = []
            remaining_reqs = []
            
            # figure out which ones are cached and which ones are new
            for req in requests:
                hsh = hash_args(attr, req)
                if hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)
            
            # actually run the LM on the requests that do not have cached results
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res
        return fn
    
    def get_cache_hook(self):
        return CacheHook(self)
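    # Illustrative usage (a sketch; the wrapped LM and cache path are placeholders):
    #   lm = CachingLM(some_lm, "lm_cache/mymodel.db")
    #   results = lm.loglikelihood(requests)  # cache misses are forwarded to some_lm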


REQUEST_RETURN_LENGTHS = {
    'loglikelihood': 2,
    'greedy_until': None,
    'loglikelihood_rolling': None,
}


class Request:
    def __init__(self, request_type, args, index=None):
        if request_type not in REQUEST_RETURN_LENGTHS.keys():
            raise NotImplementedError('The request type {} is not implemented!'.format(request_type))

        self.request_type = request_type
        self.args = args
        self.index = index
    
    def __iter__(self):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError('This request type does not return multiple arguments!')
        for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
            yield Request(self.request_type, self.args, i)
    
    def __getitem__(self, i):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError('This request type does not return multiple arguments!')
        return Request(self.request_type, self.args, i)
    
    def __eq__(self, other):
        return self.request_type == other.request_type and self.args == other.args and self.index == other.index

    def __repr__(self):
        return f"Req_{self.request_type}{self.args}[{self.index}]\n"


class RequestFactory:
    def __getattr__(self, attr):
        def fn(*args):
            return Request(attr, args)
        return fn


rf = RequestFactory()
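# Illustrative: `rf.loglikelihood(ctx, " answer")` builds
# Request('loglikelihood', (ctx, ' answer')); iterating or indexing that Request
# yields one sub-Request per element of its eventual result (2 for
# 'loglikelihood', per REQUEST_RETURN_LENGTHS above).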