import abc
from typing import Iterable
import numpy as np
import re
import os
import json
import hashlib
from sqlitedict import SqliteDict
from tqdm import tqdm
import torch
import torch.nn.functional as F

from lm_eval.metrics import mean, weighted_perplexity, weighted_mean
from lm_eval import utils
from abc import abstractmethod


class LM(abc.ABC):
    def __init__(self):
        self.cache_hook = CacheHook(None)

    @abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other 
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an 
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If 
                there is a word boundary, the space should be in the continuation. 
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass
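    # A minimal usage sketch (illustrative only; `lm` is any concrete LM subclass
    # and the returned numbers are hypothetical):
    #   lm.loglikelihood([("The capital of France is", " Paris")])
    #   # -> [(-1.9, True)]  (logprob of " Paris", and whether it is the greedy continuation)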

    @abstractmethod
    def loglikelihood_rolling(self, requests):
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list
            A list of strings
            string: str
                String for which we are computing per-token loglikelihood
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass
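    # Illustrative recap of the windowing above (hypothetical values): for a
    # document tokenized to [0 1 2 3 4 5 6 7 8 9] with max context length 4, the
    # scored windows are exactly the INPUT/PRED pairs in the docstring, and the
    # BaseLM implementation below sums the per-window log-likelihoods into a
    # single score per document.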

    # TODO: Add an optional max length
    @abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences 
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of strings continuation
            continuation: str
                The generated continuation.
        """
        pass
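    # A minimal usage sketch (illustrative only; the output is hypothetical):
    #   lm.greedy_until([("Q: What is 2+2?\nA:", ["\n"])])
    #   # -> [" 4"]  (BaseLM below truncates the generation at the first stop sequence)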

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config=None):
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)
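    # Illustrative sketch (hypothetical subclass and arguments):
    #   MyLM.create_from_arg_string("pretrained=gpt2,batch_size=4", {"device": "cuda"})
    # is roughly equivalent to MyLM(pretrained="gpt2", batch_size="4", device="cuda"),
    # with the value types determined by utils.simple_parse_args_string.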

    def set_cache_hook(self, cache_hook):
        self.cache_hook = cache_hook


class BaseLM(LM):

    @property
    @abstractmethod
    def eot_token_id(self):
        pass

    @property
    @abstractmethod
    def max_length(self):
        pass

    @property
    @abstractmethod
    def max_gen_toks(self):
        pass

    @property
    @abstractmethod
    def batch_size(self):
        pass

    @property
    @abstractmethod
    def device(self):
        pass

    @abstractmethod
    def tok_encode(self, string: str): pass
    
    @abstractmethod
    def tok_decode(self, tokens: Iterable[int]): pass

    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id): pass

    @abstractmethod
    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        pass
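    # Shape sketch (illustrative): with a batch of 2 and a padded length of 5,
    # `inps` is a LongTensor of shape [2, 5] and the return value is a float
    # tensor of shape [2, 5, vocab] holding unnormalized logits.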

    # Subclasses must implement the properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, and max_length.
    # TODO: enforce this somehow

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context)

            continuation_enc = self.tok_encode(continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        # TODO: Implement caching once we've confirmed the perplexity implementation
        # TODO: automatic batch size detection for vectorization

        loglikelihoods = []
        for string, in tqdm(requests):
            rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
                token_list=self.tok_encode(string),
                prefix_token=self.eot_token_id,
                max_seq_len=self.max_length,
                context_len=1,
            )))

            rolling_token_windows = [(None,) + x for x in rolling_token_windows]
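            # the leading None stands in for the cache_key, so these windows are not
            # individually written to the partial cache by _loglikelihood_tokens below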

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that
            string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True)
            
            # discard is_greedy
            string_nll = [x[0] for x in string_nll]
            
            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return -len(toks), tuple(toks)
        
        # TODO: automatic (variable) batch size detection for vectorization
        reord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
            inps = []
            cont_toks_list = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # gpt2    \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length+1):][:-1],
                    dtype=torch.long
                ).to(self.device)
                inplen, = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = padding_length if padding_length is not None else inplen

                # pad length from seq to padding_length
                inp = torch.cat([
                    inp,  # [seq]
                    torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device)  # [padding_length - seq]
                ], dim=0)

                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(cont)
                inplens.append(inplen)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(self._model_call(batched_inps), dim=-1).cpu()  # [batch, padding_length, vocab]

            for (cache_key, _, _), logits, inp, inplen, cont_toks \
                    in zip(chunk, multi_logits, inps, inplens, cont_toks_list):

                # Slice to original seq length
                contlen = len(cont_toks)
                logits = logits[inplen-contlen:inplen].unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return reord.get_original(res)
    
    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles untils that are 
        #       multiple tokens or that span multiple tokens correctly

        # TODO: extract to TokenizedLM?
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]
        
        reord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(reord.get_reordered()):
            if isinstance(until, str):
                until = [until]

            primary_until, = self.tok_encode(until[0])
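            # leave room for generation: keep only the last (max_length - max_gen_toks)
            # context tokens so that up to max_gen_toks new tokens still fit in the window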
            
            context_enc = torch.tensor([self.tok_encode(context)[self.max_gen_toks - self.max_length:]]).to(self.device)

            cont = self._model_generate(context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until)

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:])

            for term in until:
                s = s.split(term)[0]
            
            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)
            
            res.append(s)
        
        return reord.get_original(res)


class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation.

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary e.g.
        {"question": ..., "answer": ...} or
        (question, answer)
    """
    def __init__(self):
        self.download()
        self._training_docs = None
        self._fewshot_docs = None

    def download(self):
        """Downloads the task dataset if necessary"""
        pass

    @abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def test_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def fewshot_examples(self, k, rnd):
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    @abstractmethod
    def doc_to_text(self, doc):
        pass

    @abstractmethod
    def doc_to_target(self, doc):
        pass

    @abstractmethod
    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of 
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural 
            language description, as well as the few shot examples, and the question
            part of the document for `doc`. 
        """
        pass

    @abstractmethod
    def process_results(self, doc, results):
        """Takes a single document and the LM results, and evaluates them, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are 
            functions that aggregate a list of metric scores
        """
        pass

    @abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are 
            whether a higher value of the submetric is better
        """
        pass

    def fewshot_description(self):
        import warnings
        warnings.warn(
            "`fewshot_description` will be removed in future versions. Pass "
            "any custom descriptions to the `evaluate` function instead.",
            DeprecationWarning)
        return ""

    def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None):
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        description = description + "\n\n" if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs() if self.has_validation_docs() else self.test_docs()
                    )

                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = "\n\n".join(
                [self.doc_to_text(doc) + self.doc_to_target(doc) for doc in fewshotex]
            ) + "\n\n"

        example = self.doc_to_text(doc)
        return description + labeled_examples + example
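    # Illustrative sketch of the assembled context (hypothetical task, num_fewshot=2,
    # with a description supplied):
    #   "<description>\n\n<q1><a1>\n\n<q2><a2>\n\n<question for doc>"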


class MultipleChoiceTask(Task, abc.ABC):
    def doc_to_target(self, doc):
        return " " + doc['choices'][doc['gold']]

    def construct_requests(self, doc, ctx):
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0]
            for choice in doc['choices']
        ]

        return lls

    def process_results(self, doc, results):
        gold = doc["gold"]

        acc = 1. if np.argmax(results) == gold else 0.
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1. if np.argmax(results / completion_len) == gold else 0.

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }
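    # Illustrative sketch (hypothetical numbers): with results = [-4.0, -6.0] and
    # choices of character lengths [2, 10], acc uses argmax(results) (first choice),
    # while acc_norm uses argmax(results / [2., 10.]) = argmax([-2.0, -0.6]), i.e.
    # length-normalized log-likelihoods, which here prefer the second choice.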
    
    def higher_is_better(self):
        return {
            "acc": True,
            "acc_norm": True,
        }
    
    def aggregation(self):
        return {
            "acc": mean,
            "acc_norm": mean,
        }


class PerplexityTask(Task, abc.ABC):

    def has_training_docs(self):
        return False

    def fewshot_examples(self, k, rnd):
        assert k == 0
        return []

    def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None):
        assert num_fewshot == 0
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        return ""

    def higher_is_better(self):
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_text(self, doc):
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc, ctx):
        assert not ctx
        req = rf.loglikelihood_rolling(self.doc_to_target(doc))
        return req

    def process_results(self, doc, results):
        loglikelihood, = results
        words = self.count_words(doc)
        bytes_ = self.count_bytes(doc)
        return {
            "word_perplexity": (loglikelihood, words),
            "byte_perplexity": (loglikelihood, bytes_),
            "bits_per_byte": (-loglikelihood, self.count_bytes(doc))
        }
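    # Illustrative sketch (hypothetical numbers): for a doc with loglikelihood -120.0,
    # 30 words and 200 bytes, this returns (-120.0, 30), (-120.0, 200) and (120.0, 200)
    # respectively; `aggregation` below reduces these (loglikelihood, weight) pairs
    # across documents into corpus-level scores.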

    def aggregation(self):
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": weighted_mean
        }

    @classmethod
    def count_bytes(cls, doc):
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc):
        """ Downstream tasks with custom word boundaries should override this! """
        return len(re.split(r"\s+", doc))


def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode('utf-8')).hexdigest()
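# Illustrative sketch (hypothetical arguments):
#   hash_args("loglikelihood", ("Q: 2+2?\nA:", " 4"))
# JSON-serializes ["loglikelihood", "Q: 2+2?\nA:", " 4"] and returns its SHA-256 hex
# digest, which CachingLM below uses as the SqliteDict key.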


class CacheHook:
    def __init__(self, cachinglm):
        if cachinglm is None: 
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict
    
    def add_partial(self, attr, req, res):
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db):
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        def fn(requests):
            res = []
            remaining_reqs = []
            
            # figure out which ones are cached and which ones are new
            for req in requests:
                hsh = hash_args(attr, req)
                if hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)
            
            # actually run the LM on the requests that do not have cached results
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res
        return fn
    
    def get_cache_hook(self):
        return CacheHook(self)


REQUEST_RETURN_LENGTHS = {
    'loglikelihood': 2,
    'greedy_until': None,
    'loglikelihood_rolling': None,
}


class Request:
    def __init__(self, request_type, args, index=None):
        if request_type not in REQUEST_RETURN_LENGTHS.keys():
            raise NotImplementedError('The request type {} is not implemented!'.format(request_type))

        self.request_type = request_type
        self.args = args
        self.index = index
    
    def __iter__(self):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError('This request type does not return multiple arguments!')
        for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
            yield Request(self.request_type, self.args, i)
    
    def __getitem__(self, i):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError('This request type does not return multiple arguments!')
        return Request(self.request_type, self.args, i)
    
    def __eq__(self, other):
        return self.request_type == other.request_type and self.args == other.args and self.index == other.index

    def __repr__(self):
        return f"Req_{self.request_type}{self.args}[{self.index}]\n"


class RequestFactory:
    def __getattr__(self, attr):
        def fn(*args):
            return Request(attr, args)
        return fn


rf = RequestFactory()
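# Illustrative sketch (hypothetical arguments): inside a task's construct_requests,
#   rf.loglikelihood("Question: ...\nAnswer:", " yes")
# builds a Request("loglikelihood", ...); because loglikelihood returns two values,
# indexing the request (e.g. [0]) yields a sub-Request that selects just that slot
# once the evaluator fills in the results.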