base.py 14.4 KB
Newer Older
Leo Gao's avatar
Leo Gao committed
1
2
import abc
import random
thefazzer's avatar
thefazzer committed
3
import numpy as np
Leo Gao's avatar
Leo Gao committed
4
import re
5
from lm_eval import tasks
&'s avatar
& committed
6

Leo Gao's avatar
Leo Gao committed
7
from lm_eval.metrics import mean, perplexity, weighted_perplexity, weighted_mean
Jason Phang's avatar
gpt3  
Jason Phang committed
8

Jason Phang's avatar
Jason Phang committed
9

Leo Gao's avatar
Leo Gao committed
10
class LM(abc.ABC):
    """Abstract interface that every language-model backend implements.

    Subclasses must provide `loglikelihood`, `loglikelihood_rolling` and
    `greedy_until`. Each of them receives a list of requests and returns a
    list of results.
    """

    def __init__(self):
        # Start with a hook that has no backing store, so `add_partial` is a
        # no-op until `set_cache_hook` installs a real one.
        self.cache_hook = CacheHook(None)

    @abc.abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """

    @abc.abstractmethod
    def loglikelihood_rolling(self, requests):
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
          the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still be a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list
            A list of strings
            string: str
                String for which we are computing per-token loglikelihood
        :return: list
            One result per input string.
            NOTE(review): this docstring previously described the result as a
            pair (logprob, isgreedy), but the rest of this module treats each
            rolling result as a single log-likelihood value (`req_ret_lens`
            marks the request as not returning multiple values and
            `PerplexityTask.process_results` negates it directly) - confirm
            against the concrete LM implementations.
        """

    # TODO: Add an optional max length
    @abc.abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of strings continuation
            continuation: str
                The generated continuation.
        """

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Constructor method, in case models need additional arguments
        e.g. OpenAI API engine, paths for loading, other params

        This base implementation ignores `arg_string` and calls `cls()` with
        no arguments; subclasses that need configuration override it.

        :param arg_string: str
            Left up to individual model class to handle
        :return:
            A new instance of `cls`
        """
        return cls()

    def set_cache_hook(self, cache_hook):
        """Replace the hook stored on `self.cache_hook`.

        :param cache_hook:
            The hook object to store (`CachingLM` passes a `CacheHook`).
        """
        self.cache_hook = cache_hook

Leo Gao's avatar
Leo Gao committed
113

114
class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary e.g.
        {"question": ..., "answer": ...} or
    a tuple e.g.
        (question, answer)
    """

    def __init__(self):
        self.download()
        # Lazily-filled caches of materialized document lists, so that few-shot
        # sampling does not re-read the dataset on every call.
        self._training_docs = None
        self._fewshot_docs = None

    def download(self):
        """Downloads the task dataset if necessary"""
        pass

    @abc.abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""

    @abc.abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""

    @abc.abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""

    def training_docs(self):
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def test_docs(self):
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def fewshot_examples(self, k, rnd):
        """Draw `k` few-shot examples from the training documents.

        :param k: int
            Number of examples to draw.
        :param rnd:
            Object providing `sample(population, k)` (e.g. `random.Random`).
        :return: list
            `k` documents sampled from `training_docs()`.
        """
        if self._training_docs is None:
            # Materialize once; `sample` needs a sequence.
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    @abc.abstractmethod
    def doc_to_text(self, doc):
        """Render the prompt ("question") part of `doc` as a string."""

    @abc.abstractmethod
    def doc_to_target(self, doc):
        """Render the target ("answer") part of `doc` as a string."""

    @abc.abstractmethod
    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """

    @abc.abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """

    @abc.abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """

    @abc.abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """

    def fewshot_description(self):
        """Deprecated: emits a `DeprecationWarning` and returns an empty string."""
        import warnings
        warnings.warn(
            "`fewshot_description` will be removed in coming versions. Pass "
            "any custom descriptions to the `evaluate` function instead.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller, not this method
        )
        return ""

    def fewshot_context(self, doc, num_fewshot, rnd, description=None):
        """Build the full prompt for `doc`: description, few-shot examples, question.

        :param doc:
            The document being evaluated.
        :param num_fewshot: int
            Number of labeled examples to put before the question.
        :param rnd:
            Object providing `sample(population, k)` (e.g. `random.Random`).
        :param description: str
            Optional text placed first, followed by a blank line.
        :return: str
            `description + labeled_examples + doc_to_text(doc)`
        """
        description = description + "\n\n" if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs() if self.has_validation_docs() else self.test_docs()
                    )

                # Draw one extra so that dropping the evaluated doc still
                # leaves `num_fewshot` examples.
                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = "\n\n".join(
                self.doc_to_text(shot) + self.doc_to_target(shot) for shot in fewshotex
            ) + "\n\n"

        example = self.doc_to_text(doc)
        return description + labeled_examples + example


Leo Gao's avatar
Leo Gao committed
261
class MultipleChoiceTask(Task):
    """Task whose documents carry a list of answer `choices` and a `gold` index.

    Each choice is scored with one `loglikelihood` request; accuracy is
    reported both raw (`acc`) and with scores divided by the character length
    of each choice (`acc_norm`).
    """

    def doc_to_target(self, doc):
        """Return the gold choice text, prefixed with a single space."""
        return " " + doc['choices'][doc['gold']]

    def construct_requests(self, doc, ctx):
        """Build one log-likelihood request per answer choice.

        :param doc:
            Document with a `choices` sequence.
        :param ctx: str
            The context string the choices are appended to.
        :return: list
            For every choice, element `[0]` of the `loglikelihood` request
            for `ctx` followed by " " + choice.
        """
        return [
            rf.loglikelihood(ctx, f" {choice}")[0]
            for choice in doc['choices']
        ]

    def process_results(self, doc, results):
        """Score one document from the per-choice results.

        :param doc:
            Document with `gold` (index of the correct choice) and `choices`.
        :param results:
            One score per choice, in the order of `doc["choices"]`.
        :return: dict
            `acc` and `acc_norm`, each 1.0 when the arg-max choice equals
            `gold` and 0.0 otherwise.
        """
        gold = doc["gold"]

        # Explicit array conversion instead of relying on list / ndarray
        # falling back to NumPy's reflected division.
        scores = np.asarray(results)
        completion_len = np.array([float(len(i)) for i in doc["choices"]])

        acc = 1. if np.argmax(scores) == gold else 0.
        acc_norm = 1. if np.argmax(scores / completion_len) == gold else 0.

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }

    def higher_is_better(self):
        """Both accuracy metrics improve as they increase."""
        return {
            "acc": True,
            "acc_norm": True,
        }

    def aggregation(self):
        """Both accuracy metrics are aggregated with `mean`."""
        return {
            "acc": mean,
            "acc_norm": mean,
        }


Jason Phang's avatar
Jason Phang committed
298
299
300
301
302
303
304
305
306
class PerplexityTask(Task, abc.ABC):
    """Base class for perplexity-style tasks.

    Each document is scored as a whole with a single `loglikelihood_rolling`
    request; no few-shot examples, description or context are used.
    """

    def has_training_docs(self):
        """Perplexity tasks never draw few-shot examples from a training set."""
        return False

    def fewshot_examples(self, k, rnd):
        """Return no examples; only `k == 0` is accepted."""
        assert k == 0
        return []

    def fewshot_context(self, doc, num_fewshot, rnd, description=None):
        """Return an empty context; few-shot examples and descriptions are rejected."""
        assert num_fewshot == 0
        assert description is None
        return ""

    def higher_is_better(self):
        """Lower is better for every perplexity-style metric."""
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_text(self, doc):
        """There is no prompt part: the whole document is the target."""
        return ""

    def doc_to_target(self, doc):
        """The document itself is what gets scored."""
        return doc

    def construct_requests(self, doc, ctx):
        """Build the single rolling log-likelihood request for `doc`.

        :param doc:
            The document to score.
        :param ctx: str
            Must be empty.
        :return:
            The `loglikelihood_rolling` request for `doc_to_target(doc)`.
        """
        assert not ctx
        return rf.loglikelihood_rolling(self.doc_to_target(doc))

    def process_results(self, doc, results):
        """Pair the document's log-likelihood with its word and byte counts.

        :param doc:
            The scored document.
        :param results:
            A one-element sequence holding the document's log-likelihood.
        :return: dict
            For each metric, a `(value, weight)` tuple consumed by the
            functions returned from `aggregation`.
        """
        loglikelihood, = results
        words = self.count_words(doc)
        # Named `num_bytes` so the builtin `bytes` is not shadowed; computed
        # once and reused for both byte-based metrics.
        num_bytes = self.count_bytes(doc)
        # NOTE(review): `bits_per_byte` is a weighted mean of -loglikelihood
        # over bytes. If the LM reports natural-log likelihoods this is nats
        # per byte rather than bits - confirm the log base.
        return {
            "word_perplexity": (loglikelihood, words),
            "byte_perplexity": (loglikelihood, num_bytes),
            "bits_per_byte": (-loglikelihood, num_bytes),
        }

    def aggregation(self):
        """Map each metric name to the function that aggregates its tuples."""
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": weighted_mean,
        }

    def count_bytes(self, doc):
        """Return the length of `doc` in UTF-8 encoded bytes."""
        return len(doc.encode("utf-8"))

    def count_words(self, doc):
        """ Downstream tasks with custom word boundaries should override this! """
        return len(re.split(r"\s+", doc))
Leo Gao's avatar
Leo Gao committed
353

Jason Phang's avatar
Jason Phang committed
354

355
# Request types understood by `Request`, mapped to the number of values each
# returns. `None` marks a type that does not return multiple values, so its
# requests cannot be indexed or unpacked.
req_ret_lens = {
    'loglikelihood': 2,
    'greedy_until': None,
    'loglikelihood_rolling': None,
}

Leo Gao's avatar
Leo Gao committed
361
362
363
364
365
import os
import json
import hashlib
from sqlitedict import SqliteDict

Leo Gao's avatar
Leo Gao committed
366
367
368
def hash_args(attr, args):
    """Return a hex SHA-256 digest identifying a request.

    :param attr: str
        Name of the request type (the LM method being called).
    :param args:
        Iterable of JSON-serializable request arguments.
    :return: str
        Digest of the JSON encoding of `[attr, *args]`.
    """
    payload = [attr]
    payload.extend(args)
    encoded = json.dumps(payload).encode('utf-8')
    return hashlib.sha256(encoded).hexdigest()
Leo Gao's avatar
Leo Gao committed
369
370


Leo Gao's avatar
Leo Gao committed
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
class CacheHook:
    """Lets an LM write individual results into a `CachingLM`'s store.

    Built with `None`, the hook has no store and `add_partial` does nothing.
    """

    def __init__(self, cachinglm):
        # Share the caching wrapper's dict-like store when there is one.
        self.dbdict = None if cachinglm is None else cachinglm.dbdict

    def add_partial(self, attr, req, res):
        """Store `res` under the hash of `(attr, req)`, if a store is attached.

        :param attr: str
            Name of the request type.
        :param req:
            The request arguments, as passed to `hash_args`.
        :param res:
            The result to store.
        """
        store = self.dbdict
        if store is not None:
            store[hash_args(attr, req)] = res


Leo Gao's avatar
Leo Gao committed
386
387
388
389
class CachingLM:
    """Wraps an LM and serves repeated requests from an on-disk cache.

    Any attribute not found on the wrapper is treated as the name of an LM
    method that takes a list of requests: cached results are read from the
    store, and only the remaining requests are forwarded to the wrapped LM.
    """

    def __init__(self, lm, cache_db):
        """
        :param lm:
            The LM to wrap; must provide `set_cache_hook`.
        :param cache_db: str
            Path of the SQLite file backing the cache. Its parent directory
            is created if the path has one.
        """
        self.lm = lm
        self.cache_db = cache_db
        cache_dir = os.path.dirname(cache_db)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        """Return a caching wrapper around the LM method named `attr`.

        The returned function takes a list of requests and returns one result
        per request, in request order.

        :raises ValueError:
            (from the returned function) if the LM returns a different number
            of results than the number of requests forwarded to it.
        """
        def fn(requests):
            res = []
            # (position in `res`, request, cache key) for every cache miss.
            # Tracking positions explicitly keeps the merge correct whatever
            # values the LM returns, and the key is hashed only once.
            pending = []

            # figure out which ones are cached and which ones are new
            for req in requests:
                hsh = hash_args(attr, req)
                if hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    pending.append((len(res), req, hsh))
                    res.append(None)

            if not pending:
                # Everything was cached; no need to call the LM.
                return res

            # actually run the LM
            rem_res = list(getattr(self.lm, attr)([req for _, req, _ in pending]))
            if len(rem_res) != len(pending):
                raise ValueError(
                    "LM method {!r} returned {} results for {} requests".format(
                        attr, len(rem_res), len(pending)
                    )
                )

            # stick the new ones back into the list and also cache any of the new ones
            for (position, _, hsh), r in zip(pending, rem_res):
                res[position] = r
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res
        return fn

    def get_cache_hook(self):
        """Return a `CacheHook` bound to this wrapper's store."""
        return CacheHook(self)
Leo Gao's avatar
Leo Gao committed
434

Jason Phang's avatar
Jason Phang committed
435

436
437
438
439
class Request:
    """A deferred LM call: a request type, its arguments and an optional index.

    For request types that return several values (see `req_ret_lens`),
    indexing or iterating a request gives new requests that share the type
    and arguments and carry an `index`.
    """

    def __init__(self, type, args, index=None):
        """
        :param type: str
            Request type; must be a key of `req_ret_lens`.
        :param args:
            Arguments of the request.
        :param index:
            Which of the request's return values this object refers to, or
            None.
        :raises NotImplementedError:
            If `type` is not a known request type.
        """
        if type not in req_ret_lens:
            raise NotImplementedError('The request type {} is not implemented!'.format(type))

        self.type = type
        self.args = args
        self.index = index

    def __iter__(self):
        """Yield one indexed request per return value.

        :raises IndexError:
            If this request type does not return multiple values.
        """
        ret_len = req_ret_lens[self.type]
        if ret_len is None:
            raise IndexError('This request type does not return multiple arguments!')
        for i in range(ret_len):
            yield Request(self.type, self.args, i)

    def __getitem__(self, i):
        """Return a copy of this request that refers to return value `i`.

        :raises IndexError:
            If this request type does not return multiple values.
        """
        if req_ret_lens[self.type] is None:
            raise IndexError('This request type does not return multiple arguments!')
        return Request(self.type, self.args, i)

    def __eq__(self, other):
        """Requests are equal when type, args and index all match."""
        if not isinstance(other, Request):
            # Let Python try the reflected comparison instead of failing with
            # AttributeError on unrelated objects.
            return NotImplemented
        return self.type == other.type and self.args == other.args and self.index == other.index

    # Defining __eq__ without __hash__ leaves instances unhashable.

    def __repr__(self):
        return f"Req_{self.type}{self.args}[{self.index}]\n"
Jason Phang's avatar
Jason Phang committed
462

Leo Gao's avatar
Leo Gao committed
463
464
class RequestFactory:
    """Builds `Request` objects through attribute access.

    `factory.<name>(*args)` is equivalent to `Request("<name>", args)`. The
    request type is only validated when the returned function is called.
    """

    def __getattr__(self, attr):
        """Return a function that creates a `Request` of type `attr`."""
        def fn(*args):
            return Request(attr, args)
        return fn


# Shared module-level factory: `rf.<request_type>(*args)` builds a `Request` of that type.
rf = RequestFactory()