import abc
import random
from typing import Iterable
import numpy as np
import re
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F

from lm_eval.metrics import mean, perplexity, weighted_perplexity, weighted_mean
from lm_eval import utils
from abc import abstractmethod

class LM(abc.ABC):
    def __init__(self):
        self.cache_hook = CacheHook(None)

    @abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other 
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an 
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If 
                there is a word boundary, the space should be in the continuation. 
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy: bool
                Whether `continuation` would be generated by greedy sampling from `context`
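
        Example (illustrative only; the values are made up and `lm` is a
        hypothetical concrete subclass of LM):
            lm.loglikelihood([("The capital of France is", " Paris")])
            => [(-1.9, True)]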
        """
        pass

    @abstractmethod
    def loglikelihood_rolling(self, requests):
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list
            A list of strings
            string: str
                String for which we are computing per-token loglikelihood
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy: bool
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass

    # TODO: Add an optional max length
    @abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences 
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of strings (continuation)
            continuation: str
                The generated continuation.
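
        Example (illustrative only; the output is made up and `lm` is a
        hypothetical concrete subclass of LM):
            lm.greedy_until([("Once upon a time", ["."])])
            => [" there was a dog."]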
        """
        pass

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config={}):
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)
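
    # Illustrative example (assumes `utils.simple_parse_args_string` parses
    # comma-separated key=value pairs; `MyLM` is a hypothetical subclass):
    #   MyLM.create_from_arg_string("device=cuda:0", {"batch_size": 4})
    #   is roughly equivalent to MyLM(device="cuda:0", batch_size=4)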

    def set_cache_hook(self, cache_hook):
        self.cache_hook = cache_hook


class BaseLM(LM):
    @abstractmethod
    def tok_encode(self, string: str): pass
    
    @abstractmethod
    def tok_decode(self, tokens: Iterable[int]): pass

    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id): pass

    @abstractmethod
    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        pass

    # Subclasses must implement the properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, and max_length.
    # TODO: enforce this somehow
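    #
    # A minimal illustrative sketch (all values made up) of the properties a
    # concrete subclass would define:
    #
    #   class MyLM(BaseLM):
    #       @property
    #       def eot_token_id(self): return 50256
    #       @property
    #       def max_length(self): return 2048
    #       @property
    #       def max_gen_toks(self): return 256
    #       @property
    #       def batch_size(self): return 8
    #       @property
    #       def device(self): return "cuda"
    #       @property
    #       def vocab_size(self): return 50257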

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context)

            continuation_enc = self.tok_encode(continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        # TODO: Implement caching once we've confirmed the perplexity implementation
        # TODO: automatic batch size detection for vectorization

        loglikelihoods = []
        for string, in tqdm(requests):
            rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
                token_list=self.tok_encode(string),
                prefix_token=self.eot_token_id,
                max_seq_len=self.max_length,
                context_len=1,
            )))

            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for that
            string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True)
            
            # discard is_greedy
            string_nll = [x[0] for x in string_nll]
            
            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be overestimates rather than underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch padded context length.
            #   this is useful to simplify the batching logic and more importantly to make automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
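            # e.g. (illustrative): for x = (key, [5, 3], [9, 9]) this returns (-4, (5, 3, 9, 9))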

            toks = x[1] + x[2]
            return (-len(toks), tuple(toks))
        
        # TODO: automatic (variable) batch size detection for vectorization
        reord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
            inps = []
            contlens = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
                # gpt2    \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the [:, -len(continuation_enc):, :self.vocab_size] slice
                # cont_toks      4 5 6 7 8 9

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length+1):][:-1]
                , dtype=torch.long).to(self.device)
                inplen, = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = padding_length if padding_length is not None else inplen

                # pad to length
                inp = torch.cat([
                    inp, # [seq]
                    torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device) # [padding_length - seq]
                ], dim=0)

                inps.append(inp.unsqueeze(0))
                contlens.append(cont)
                inplens.append(inplen)

            multi_logits = F.log_softmax(self._model_call(torch.cat(inps, dim=0)), dim=-1).cpu()  # [batch, seq, vocab]

            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(chunk, multi_logits, inps, inplens, contlens):
                contlen = len(cont_toks)

                logits = logits[inplen-contlen:inplen].unsqueeze(0) # [1, seq, vocab]

                greedy_tokens = logits.argmax(dim=-1)

                # cont_toks :: [1, seq]
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)

                max_equal = (greedy_tokens == cont_toks).all()

                #last_token_slice = logits[:, -1, :].squeeze(0).tolist()

                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq]

                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return reord.get_original(res)
    
    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles untils that are 
        # multiple tokens or that span multiple tokens correctly

        # TODO: extract to TokenizedLM?
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return (len(toks), x[0])
        
        reord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(reord.get_reordered()):
            if isinstance(until, str): until = [until]

            primary_until, = self.tok_encode(until[0])
            
            context_enc = torch.tensor([self.tok_encode(context)[self.max_gen_toks - self.max_length:]]).to(self.device)

            cont = self._model_generate(context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until)

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:])

            for term in until:
                s = s.split(term)[0]
            
            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)
            
            res.append(s)
        
        return reord.get_original(res)


class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary e.g.
        {"question": ..., "answer": ...} or
        {"question": ..., question, answer)
    """
    def __init__(self):
        self.download()
        self._training_docs = None
        self._fewshot_docs = None

    def download(self):
        """Downloads the task dataset if necessary"""
        pass

    @abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def test_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def fewshot_examples(self, k, rnd):
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    @abstractmethod
    def doc_to_text(self, doc):
        pass

    @abstractmethod
    def doc_to_target(self, doc):
        pass

    @abstractmethod
    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of 
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural 
            language description, as well as the few-shot examples, and the question
            part of the document for `doc`. 
        """
        pass

    @abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a 
        dict where keys are the names of submetrics and values are the values of 
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are 
            functions that aggregate a list of metric scores
        """
        pass

    @abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are 
            whether a higher value of the submetric is better
        """
        pass

    def fewshot_description(self):
        return ""

    def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
        raw_description = self.fewshot_description()
        description = (raw_description + "\n===\n\n") if provide_description and raw_description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(self.validation_docs() if self.has_validation_docs() else self.test_docs())

                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = "\n\n".join(
                [self.doc_to_text(doc) + self.doc_to_target(doc) for doc in fewshotex]
            ) + "\n\n"

        example = self.doc_to_text(doc)
        return description + labeled_examples + example
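
    # Illustrative shape of the returned prompt (all text made up), with
    # provide_description=True and num_fewshot=2:
    #   "<description>\n===\n\n<ex1 text><ex1 target>\n\n<ex2 text><ex2 target>\n\n<doc text>"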


class MultipleChoiceTask(Task):
    def doc_to_target(self, doc):
        return " " + doc['choices'][doc['gold']]

    def construct_requests(self, doc, ctx):
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0]
            for choice in doc['choices']
        ]

        return lls

    def process_results(self, doc, results):
        gold = doc["gold"]

        acc = 1. if np.argmax(results) == gold else 0.
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
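        # length-normalized accuracy: divide each choice's loglikelihood by its
        # character length before taking the argmax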
        acc_norm = 1. if np.argmax(results / completion_len) == gold else 0.

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }
    
    def higher_is_better(self):
        return {
            "acc": True,
            "acc_norm": True,
        }
    
    def aggregation(self):
        return {
            "acc": mean,
            "acc_norm": mean,
        }


class PerplexityTask(Task, abc.ABC):

    def has_training_docs(self):
        return False

    def fewshot_description(self):
        return ""

    def fewshot_examples(self, k, rnd):
        assert k == 0
        return []

    def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
        assert num_fewshot == 0
        assert not provide_description
        return ""

    def higher_is_better(self):
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_text(self, doc):
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc, ctx):
        assert not ctx
        req = rf.loglikelihood_rolling(self.doc_to_target(doc))
        return req

    def process_results(self, doc, results):
        loglikelihood, = results
        words = self.count_words(doc)
        bytes = self.count_bytes(doc)
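
        # each value below is a (loglikelihood, weight) pair; the matching
        # aggregation functions (weighted_perplexity / weighted_mean) combine
        # these pairs across documents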
        return {
            "word_perplexity": (loglikelihood, words),
            "byte_perplexity": (loglikelihood, bytes),
            "bits_per_byte": (-loglikelihood, self.count_bytes(doc))
        }

    def aggregation(self):
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": weighted_mean
        }

    def count_bytes(self, doc):
        return len(doc.encode("utf-8"))
    
    def count_words(self, doc):
        """ Downstream tasks with custom word boundaries should override this! """
        return len(re.split(r"\s+", doc))


req_ret_lens = {
    'loglikelihood': 2,
    'greedy_until': None,
    'loglikelihood_rolling': None,
}
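
# The values above are how many result values a request of that type unpacks into
# (e.g. a loglikelihood result is the pair (logprob, is_greedy), so tasks can write
# rf.loglikelihood(ctx, cont)[0]); None means the result is not indexable.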

import os
import json
import hashlib
from sqlitedict import SqliteDict

def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode('utf-8')).hexdigest()
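
# e.g. (illustrative): hash_args("loglikelihood", ("ctx", " cont")) hashes the JSON
# string '["loglikelihood", "ctx", " cont"]' and returns its hex SHA-256 digest.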


class CacheHook:
    def __init__(self, cachinglm):
        if cachinglm is None: 
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict
    
    def add_partial(self, attr, req, res):
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db):
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db): os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        def fn(requests):
            res = []
            remaining_reqs = []
            
            # figure out which ones are cached and which ones are new
            for req in requests:
                hsh = hash_args(attr, req)
                if hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)
            
            # actually run the LM
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None: resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res
        return fn
    
    def get_cache_hook(self):
        return CacheHook(self)


class Request:
    def __init__(self, type, args, index=None):
        if type not in req_ret_lens.keys():
            raise NotImplementedError('The request type {} is not implemented!'.format(type))

        self.type = type
        self.args = args
        self.index = index
    
    def __iter__(self):
        if req_ret_lens[self.type] is None:
            raise IndexError('This request type does not return multiple arguments!')
        i = 0
        for i in range(req_ret_lens[self.type]):
            yield Request(self.type, self.args, i)
    
    def __getitem__(self, i):
        if req_ret_lens[self.type] is None:
            raise IndexError('This request type does not return multiple arguments!')
        return Request(self.type, self.args, i)
    
    def __eq__(self, other):
        return self.type == other.type and self.args == other.args and self.index == other.index

    def __repr__(self):
        return f"Req_{self.type}{self.args}[{self.index}]\n"

class RequestFactory:
    def __getattr__(self, attr):
        def fn(*args):
            return Request(attr, args)
        return fn


rf = RequestFactory()
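
# Illustrative usage (hypothetical strings): a task's construct_requests typically does
#   req = rf.loglikelihood("2 + 2 =", " 4")   # -> Request('loglikelihood', ('2 + 2 =', ' 4'))
# and the surrounding evaluation harness is expected to run these requests against
# the LM and hand the results back to Task.process_results.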