import abc
import random
import numpy as np
import re

from lm_eval.metrics import mean, perplexity, weighted_perplexity, weighted_mean


class LM(abc.ABC):
    def __init__(self):
        self.cache_hook = CacheHook(None)

    @abc.abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other 
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an 
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If 
                there is a word boundary, the space should be in the continuation. 
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass
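        # Illustrative sketch, not part of the abstract interface; `my_lm` stands
        # in for any concrete LM subclass:
        #
        #   requests = [
        #       ("The capital of France is", " Paris"),
        #       ("", "Hello world"),  # implementations must accept an empty context
        #   ]
        #   my_lm.loglikelihood(requests)
        #   # -> [(logprob, isgreedy), ...], one pair per (context, continuation)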

    @abc.abstractmethod
    def loglikelihood_rolling(self, requests):
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list
            A list of strings
            string: str
                String for which we are computing per-token loglikelihood
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass
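        # Illustrative sketch (hypothetical `my_lm` subclass): each document is
        # scored in full, chunked to the model's max context length as described above.
        #
        #   my_lm.loglikelihood_rolling(["A complete document ...", "Another document ..."])
        #   # -> one result per document, each computed separately (never concatenated)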

    # TODO: Add an optional max length
    @abc.abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences 
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of continuation strings
            continuation: str
                The generated continuation.
        """
        pass
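        # Illustrative sketch (hypothetical `my_lm` subclass):
        #
        #   requests = [("Q: What is 2 + 2?\nA:", ["\n"])]
        #   my_lm.greedy_until(requests)
        #   # -> [" 4"]  (generation stops once any `until` sequence is produced)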

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Constructor method, in case models need additional arguments,
        e.g. an OpenAI API engine, paths for loading, or other params.

        :param arg_string: str
            Left up to individual model class to handle

        """
        return cls()
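        # Illustrative sketch: subclasses typically parse a model-specific string;
        # the format shown (and the class name) is an assumption, not part of this
        # base class.
        #
        #   lm = SomeModelClass.create_from_arg_string("engine=davinci")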

    def set_cache_hook(self, cache_hook):
        self.cache_hook = cache_hook


class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary e.g.
        {"question": ..., "answer": ...} or
        (question, answer).
    """
    def __init__(self):
        self.download()
        self._training_docs = None
        self._fewshot_docs = None

    def download(self):
        """Downloads the task dataset if necessary"""
        pass

    def should_decontaminate(self):
        """Whether this task supports decontamination against model training set."""
        return False

    @abc.abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abc.abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abc.abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def test_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def fewshot_examples(self, k, rnd):
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    def doc_to_decontamination_query(self, doc):
        raise NotImplementedError(
            "Override doc_to_decontamination_query with a document-specific decontamination query."
        )

    @abc.abstractmethod
    def doc_to_text(self, doc):
        pass

    @abc.abstractmethod
    def doc_to_target(self, doc):
        pass

    @abc.abstractmethod
    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of 
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, the few-shot examples, and the question part of the
            document for `doc`.
        """
        pass
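        # Illustrative sketch of a typical implementation (the "answer" field and
        # the single loglikelihood request are assumptions, not requirements):
        #
        #   def construct_requests(self, doc, ctx):
        #       return rf.loglikelihood(ctx, " " + doc["answer"])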

    @abc.abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluate, returning a
        dict where keys are the names of submetrics and values are the values of 
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abc.abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are 
            functions that aggregate a list of metric scores
        """
        pass
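        # Illustrative sketch of a typical implementation, reusing the aggregators
        # imported from lm_eval.metrics above:
        #
        #   def aggregation(self):
        #       return {"acc": mean}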

    @abc.abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are 
            whether a higher value of the submetric is better
        """
        pass

    def fewshot_description(self):
        return ""

    def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
        raw_description = self.fewshot_description()
        description = (raw_description + "\n===\n\n") if provide_description and raw_description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(self.validation_docs() if self.has_validation_docs() else self.test_docs())

                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = "\n\n".join(
                [self.doc_to_text(doc) + self.doc_to_target(doc) for doc in fewshotex]
            ) + "\n\n"

        example = self.doc_to_text(doc)
        return description + labeled_examples + example
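        # Illustrative layout of the returned prompt (description and examples are
        # optional; separators shown literally):
        #
        #   <description>\n===\n\n<ex 1 text><ex 1 target>\n\n...\n\n<ex k text><ex k target>\n\n<doc text>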


class MultipleChoiceTask(Task):
    def doc_to_target(self, doc):
        return " " + doc['choices'][doc['gold']]

    def construct_requests(self, doc, ctx):
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0]
            for choice in doc['choices']
        ]

        return lls

    def process_results(self, doc, results):
        gold = doc["gold"]

        acc = 1. if np.argmax(results) == gold else 0.
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
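        # acc_norm re-ranks the choices by loglikelihood normalized by each choice's
        # character length, reducing the bias toward shorter completions.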
        acc_norm = 1. if np.argmax(results / completion_len) == gold else 0.

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }
    
    def higher_is_better(self):
        return {
            "acc": True,
            "acc_norm": True,
        }
    
    def aggregation(self):
        return {
            "acc": mean,
            "acc_norm": mean,
        }


class PerplexityTask(Task, abc.ABC):

    def should_decontaminate(self):
        """Whether this task supports decontamination against model training set."""
        return True

    def has_training_docs(self):
        return False

    def fewshot_description(self):
        return ""

    def fewshot_examples(self, k, rnd):
        assert k == 0
        return []

    def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
        assert num_fewshot == 0
        assert not provide_description
        return ""

    def higher_is_better(self):
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_decontamination_query(self, doc):
        return doc

    def doc_to_text(self, doc):
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc, ctx):
        assert not ctx
        req = rf.loglikelihood_rolling(self.doc_to_target(doc))
        return req

    def process_results(self, doc, results):
        loglikelihood, = results
        words = self.count_words(doc)
        bytes = self.count_bytes(doc)
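        # Each metric is emitted as a (log-likelihood, weight) pair; the matching
        # aggregation functions below reduce these pairs to corpus-level values.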
        return {
            "word_perplexity": (loglikelihood, words),
            "byte_perplexity": (loglikelihood, bytes),
            "bits_per_byte": (-loglikelihood, self.count_bytes(doc))
        }

    def aggregation(self):
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": weighted_mean
        }

    def count_bytes(self, doc):
        return len(doc.encode("utf-8"))
    
    def count_words(self, doc):
        """ Downstream tasks with custom word boundaries should override this! """
        return len(re.split(r"\s+", doc))


req_ret_lens = {
    'loglikelihood': 2,
    'greedy_until': None,
    'loglikelihood_rolling': None,
}
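# Number of values a request of each type yields when indexed (see
# Request.__iter__ / __getitem__ below); None means the result is consumed whole.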

import os
import json
import hashlib
from sqlitedict import SqliteDict

def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode('utf-8')).hexdigest()


class CacheHook:
    def __init__(self, cachinglm):
        if cachinglm is None: 
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict
    
    def add_partial(self, attr, req, res):
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db):
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        def fn(requests):
            res = []
            remaining_reqs = []
            
            # figure out which ones are cached and which ones are new
            for req in requests:
                hsh = hash_args(attr, req)
                if hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)
            
            # actually run the LM
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res
        return fn
    
    def get_cache_hook(self):
        return CacheHook(self)
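
# Illustrative usage sketch (assumes `base_lm` is an already-instantiated LM subclass):
#
#   cached_lm = CachingLM(base_lm, "lm_cache/model.db")
#   cached_lm.loglikelihood([("context", " continuation")])
#   # repeated identical requests are answered from the SQLite cache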


class Request:
    def __init__(self, type, args, index=None):
        if type not in req_ret_lens.keys():
            raise NotImplementedError('The request type {} is not implemented!'.format(type))

        self.type = type
        self.args = args
        self.index = index
    
    def __iter__(self):
        if req_ret_lens[self.type] is None:
            raise IndexError('This request type does not return multiple arguments!')
        i = 0
        for i in range(req_ret_lens[self.type]):
            yield Request(self.type, self.args, i)
    
    def __getitem__(self, i):
        if req_ret_lens[self.type] is None:
            raise IndexError('This request type does not return multiple arguments!')
        return Request(self.type, self.args, i)
    
    def __eq__(self, other):
        return self.type == other.type and self.args == other.args and self.index == other.index

    def __repr__(self):
        return f"Req_{self.type}{self.args}[{self.index}]\n"

class RequestFactory:
    def __getattr__(self, attr):
        def fn(*args):
            return Request(attr, args)
        return fn


rf = RequestFactory()
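
# Illustrative sketch: tasks build requests through this module-level factory,
# e.g. inside a task's construct_requests:
#
#   ll_request = rf.loglikelihood(ctx, " continuation")
#   logprob_request = ll_request[0]   # selects the logprob slot of the result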