base.py 14.2 KB
Newer Older
Leo Gao's avatar
Leo Gao committed
1
2
import abc
import random
thefazzer's avatar
thefazzer committed
3
import numpy as np
Leo Gao's avatar
Leo Gao committed
4
import re
&'s avatar
& committed
5

Leo Gao's avatar
Leo Gao committed
6
from lm_eval.metrics import mean, perplexity, weighted_mean
Jason Phang's avatar
gpt3  
Jason Phang committed
7

Jason Phang's avatar
Jason Phang committed
8

Leo Gao's avatar
Leo Gao committed
9
class LM(abc.ABC):
    """Abstract interface implemented by every language-model backend.

    Each request method receives a list of requests and returns a list of
    results in the same order as the requests.
    """

    def __init__(self):
        # Start with a hook that has no cache attached (its add_partial is a
        # no-op); a CachingLM installs its own hook via set_cache_hook.
        self.cache_hook = CacheHook(None)

    @abc.abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass

    @abc.abstractmethod
    def loglikelihood_rolling(self, requests):
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
          the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list
            A list of strings
            string: str
                String for which we are computing per-token loglikelihood
        :return: list
            A list of floats, one per input string (a single value per string,
            with no greedy flag)
            logprob: float
                The log probability of `string`
        """
        pass

    # TODO: Add an optional max length
    @abc.abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of continuation strings, one per request
            continuation: str
                The generated continuation.
        """
        pass

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Constructor method, in case models need additional arguments
        e.g. OpenAI API engine, paths for loading, other params

        :param arg_string: str
            Left up to individual model class to handle; this default
            implementation ignores it and calls ``cls()``.
        """
        return cls()

    def set_cache_hook(self, cache_hook):
        """Replace the hook the model uses to report partial results."""
        self.cache_hook = cache_hook

Leo Gao's avatar
Leo Gao committed
112

113
class Task(abc.ABC):
&'s avatar
& committed
114
115
116
117
118
119
120
121
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary e.g.
        {"question": ..., "answer": ...} or
        {"question": ..., question, answer)
    """
Leo Gao's avatar
Leo Gao committed
122
123
    def __init__(self):
        self.download()
124
        self._training_docs = None
125
        self._fewshot_docs = None
sdtblck's avatar
sdtblck committed
126
127
128
129
130

    def download(self):
        """Downloads the task dataset if necessary"""
        pass

131
132
    @abc.abstractmethod
    def has_training_docs(self):
Jason Phang's avatar
checkin  
Jason Phang committed
133
        """Whether the task has a training set"""
134
        pass
135

136
137
    @abc.abstractmethod
    def has_validation_docs(self):
Jason Phang's avatar
checkin  
Jason Phang committed
138
139
140
141
142
143
        """Whether the task has a validation set"""
        pass

    @abc.abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
144
145
        pass

Leo Gao's avatar
Leo Gao committed
146
    def training_docs(self):
Jason Phang's avatar
checkin  
Jason Phang committed
147
148
149
150
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
Leo Gao's avatar
Leo Gao committed
151
        return []
152

Leo Gao's avatar
Leo Gao committed
153
    def validation_docs(self):
154
155
156
157
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
Leo Gao's avatar
Leo Gao committed
158
        return []
159

Leo Gao's avatar
Leo Gao committed
160
    def test_docs(self):
161
162
163
164
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
Leo Gao's avatar
Leo Gao committed
165
        return []
Leo Gao's avatar
Leo Gao committed
166

167
    def fewshot_examples(self, k, rnd):
168
169
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())
170

Leo Gao's avatar
Leo Gao committed
171
        return rnd.sample(self._training_docs, k)
Leo Gao's avatar
Leo Gao committed
172
173

    @abc.abstractmethod
Leo Gao's avatar
Update  
Leo Gao committed
174
175
176
177
178
    def doc_to_text(self, doc):
        pass

    @abc.abstractmethod
    def doc_to_target(self, doc):
Leo Gao's avatar
Leo Gao committed
179
        pass
Leo Gao's avatar
Leo Gao committed
180
181

    @abc.abstractmethod
182
    def construct_requests(self, doc, ctx):
Leo Gao's avatar
Leo Gao committed
183
184
185
        """ Uses RequestFactory to construct Requests and returns an iterable of 
        Requests which will be sent to the LM.

186
187
        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
Leo Gao's avatar
Leo Gao committed
188
        :param ctx: str
189
190
191
            The context string, generated by fewshot_context. This includes the natural 
            language description, as well as the few shot examples, and the question
            part of the document for `doc`. 
Leo Gao's avatar
Leo Gao committed
192
        """
Leo Gao's avatar
Leo Gao committed
193
        pass
194

Leo Gao's avatar
Leo Gao committed
195
    @abc.abstractmethod
Leo Gao's avatar
Leo Gao committed
196
    def process_results(self, doc, results):
Leo Gao's avatar
Update  
Leo Gao committed
197
        """Take a single document and the LM results and evaluates, returning a 
198
199
        dict where keys are the names of submetrics and values are the values of 
        the metric for that one document
Leo Gao's avatar
Leo Gao committed
200
201
202
203
204

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
Jason Phang's avatar
checkin  
Jason Phang committed
205
        """
Leo Gao's avatar
Leo Gao committed
206
        pass
Jason Phang's avatar
gpt3  
Jason Phang committed
207

208
209
210
    @abc.abstractmethod
    def aggregation(self):
        """
&'s avatar
& committed
211
        :returns: {str: [metric_score] -> float}
212
            A dictionary where keys are the names of submetrics and values are 
&'s avatar
& committed
213
            functions that aggregate a list of metric scores
214
215
216
217
218
219
220
221
222
223
224
225
        """
        pass

    @abc.abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are 
            whether a higher value of the submetric is better
        """
        pass

Jason Phang's avatar
Jason Phang committed
226
    def fewshot_description(self):
Jason Phang's avatar
checkin  
Jason Phang committed
227
228
        return ""

229
    def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
Jason Phang's avatar
Jason Phang committed
230
        raw_description = self.fewshot_description()
Jason Phang's avatar
Jason Phang committed
231
        description = (raw_description + "\n===\n\n") if provide_description and raw_description else ""
232

233
234
        if num_fewshot == 0:
            labeled_examples = ""
235
        else:
236
237
238
239
240
241
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(self.validation_docs() if self.has_validation_docs else self.test_docs())
242

243
                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
244

245
246
                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
247

248
            labeled_examples = "\n\n".join(
249
                [self.doc_to_text(doc) + self.doc_to_target(doc) for doc in fewshotex]
250
            ) + "\n\n"
Leo Gao's avatar
Update  
Leo Gao committed
251

252
        example = self.doc_to_text(doc)
Leo Gao's avatar
Leo Gao committed
253
254
255
        return description + labeled_examples + example


Leo Gao's avatar
Leo Gao committed
256
class MultipleChoiceTask(Task):
    """Base class for tasks whose docs hold a ``choices`` list and a ``gold`` index.

    Every choice is scored with one loglikelihood request. ``acc`` checks
    whether the highest raw log-likelihood is the gold choice; ``acc_norm``
    does the same after dividing each log-likelihood by the character length
    of its choice.
    """

    def doc_to_target(self, doc):
        # The gold choice, prefixed with the word-boundary space.
        gold_choice = doc['choices'][doc['gold']]
        return " " + gold_choice

    def construct_requests(self, doc, ctx):
        # Keep only element 0 (the log probability) of each loglikelihood request.
        return [rf.loglikelihood(ctx, " {}".format(option))[0] for option in doc['choices']]

    def process_results(self, doc, results):
        gold = doc["gold"]
        char_lengths = np.array([float(len(option)) for option in doc["choices"]])

        best_raw = np.argmax(results)
        best_normalized = np.argmax(results / char_lengths)

        return {
            "acc": 1. if best_raw == gold else 0.,
            "acc_norm": 1. if best_normalized == gold else 0.,
        }

    def higher_is_better(self):
        return {metric: True for metric in ("acc", "acc_norm")}

    def aggregation(self):
        return {metric: mean for metric in ("acc", "acc_norm")}


Jason Phang's avatar
Jason Phang committed
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
class PerplexityTask(Task, abc.ABC):
    """Base class for perplexity tasks scored with ``loglikelihood_rolling``.

    Docs are raw strings (``doc_to_text`` returns them unchanged) and no
    few-shot context or description is ever used.
    """

    def has_training_docs(self):
        return False

    def fewshot_description(self):
        return ""

    def fewshot_examples(self, k, rnd):
        assert k == 0
        return []

    def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
        assert num_fewshot == 0
        assert not provide_description
        return ""

    def higher_is_better(self):
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_text(self, doc):
        return doc

    def doc_to_target(self, doc):
        raise NotImplementedError()

    def construct_requests(self, doc, ctx):
        assert not ctx
        req = rf.loglikelihood_rolling(doc)
        return req

    def process_results(self, doc, results):
        """Return per-doc metric inputs from the single rolling loglikelihood.

        ``bits_per_byte`` is a (numerator, weight) pair for ``weighted_mean``.
        """
        loglikelihood, = results
        text = self.doc_to_text(doc)
        num_bytes = self.count_bytes(text)
        return {
            "word_perplexity": loglikelihood / self.count_words(text),
            "byte_perplexity": loglikelihood / num_bytes,
            # Convert nats to bits. This assumes loglikelihood is a natural-log
            # value, as produced by log_softmax / API logprobs. TODO confirm for
            # every LM backend.
            "bits_per_byte": (-loglikelihood / np.log(2), num_bytes),
        }

    def aggregation(self):
        return {
            "word_perplexity": perplexity,
            "byte_perplexity": perplexity,
            "bits_per_byte": weighted_mean
        }

    def count_bytes(self, s):
        """Number of bytes in the UTF-8 encoding of `s`."""
        return len(s.encode("utf-8"))

    def count_words(self, s):
        """ Downstream tasks with custom word boundaries should override this! """
        return len(re.split(r"\s+", s))
Leo Gao's avatar
Leo Gao committed
349

Jason Phang's avatar
Jason Phang committed
350

351
# How many indexed sub-requests ``Request.__iter__`` yields for each request
# type; None means the result is a single value that Request refuses to index
# or unpack.
req_ret_lens = dict(
    loglikelihood=2,
    greedy_until=None,
    loglikelihood_rolling=None,
)

Leo Gao's avatar
Leo Gao committed
357
358
359
360
361
import os
import json
import hashlib
from sqlitedict import SqliteDict

Leo Gao's avatar
Leo Gao committed
362
363
364
def hash_args(attr, args):
    """Return a SHA-256 hex digest keying the request type `attr` with `args`.

    The name and the args are serialized together as one JSON list, so every
    element of `args` must be JSON-serializable.
    """
    payload = json.dumps([attr, *args]).encode('utf-8')
    return hashlib.sha256(payload).hexdigest()
Leo Gao's avatar
Leo Gao committed
365
366


Leo Gao's avatar
Leo Gao committed
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
class CacheHook:
    """Lets an LM write individual results into a CachingLM's cache.

    Built with ``None`` it has no cache and ``add_partial`` does nothing;
    built with a CachingLM it writes into that object's ``dbdict``.
    """

    def __init__(self, cachinglm):
        self.dbdict = None if cachinglm is None else cachinglm.dbdict

    def add_partial(self, attr, req, res):
        """Store `res` under the cache key for request `req` of type `attr`."""
        if self.dbdict is not None:
            self.dbdict[hash_args(attr, req)] = res


Leo Gao's avatar
Leo Gao committed
382
383
384
385
386
387
388
class CachingLM:
    """Wraps an LM so that request results are memoized in a SqliteDict.

    Any attribute not found on the wrapper (e.g. ``loglikelihood``) becomes a
    function that serves cached results and forwards only the uncached
    requests to the wrapped LM.
    """

    def __init__(self, lm, cache_db):
        """
        :param lm: LM
            The model whose results are cached.
        :param cache_db: str
            Path of the sqlite cache file; its parent directory is created if needed.
        """
        self.lm = lm
        self.cache_db = cache_db
        # A bare filename has no directory part, and os.makedirs("") raises.
        cache_dir = os.path.dirname(cache_db)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        # Dunder probes (e.g. copy.deepcopy looking for __deepcopy__) must not
        # be mistaken for LM request methods.
        if attr.startswith("__") and attr.endswith("__"):
            raise AttributeError(attr)

        def fn(requests):
            res = []
            remaining_reqs = []
            remaining_idx = []

            # figure out which ones are cached and which ones are new
            for i, req in enumerate(requests):
                hsh = hash_args(attr, req)
                if hsh in self.dbdict:
                    ob = self.dbdict[hsh]
                    assert ob is not None
                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)
                    remaining_idx.append(i)

            if remaining_reqs:
                # actually run the LM on the uncached requests only
                rem_res = getattr(self.lm, attr)(remaining_reqs)

                # Put each new result back in its original slot by recorded index
                # (not by scanning for None, which breaks if the LM returns None),
                # and cache it.
                for i, req, r in zip(remaining_idx, remaining_reqs, rem_res):
                    res[i] = r
                    self.dbdict[hash_args(attr, req)] = r
                self.dbdict.commit()

            return res
        return fn

    def get_cache_hook(self):
        return CacheHook(self)
Leo Gao's avatar
Leo Gao committed
430

Jason Phang's avatar
Jason Phang committed
431

432
433
434
435
class Request:
    """A single deferred call to an LM method, as built by RequestFactory.

    :param type: str
        Name of the LM method to call; must be a key of ``req_ret_lens``.
    :param args: tuple
        Positional arguments for that call.
    :param index: int or None
        When set, selects one element of the call's returned tuple.
    """

    def __init__(self, type, args, index=None):
        if type not in req_ret_lens:
            raise NotImplementedError('The request type {} is not implemented!'.format(type))

        self.type = type
        self.args = args
        self.index = index

    def __iter__(self):
        # Allows unpacking multi-value requests, e.g. `ll, is_greedy = rf.loglikelihood(...)`.
        if req_ret_lens[self.type] is None:
            raise IndexError('This request type does not return multiple arguments!')
        for i in range(req_ret_lens[self.type]):
            yield Request(self.type, self.args, i)

    def __getitem__(self, i):
        n = req_ret_lens[self.type]
        if n is None:
            raise IndexError('This request type does not return multiple arguments!')
        # Fail fast instead of building a request whose result lookup fails later.
        if not -n <= i < n:
            raise IndexError('Index {} out of range for request type {}'.format(i, self.type))
        return Request(self.type, self.args, i)

    # Defining __eq__ leaves Request unhashable (__hash__ is None); args may
    # hold lists (e.g. greedy_until's `until`), so a tuple-based hash would not
    # always work.
    def __eq__(self, other):
        if not isinstance(other, Request):
            return NotImplemented
        return self.type == other.type and self.args == other.args and self.index == other.index

    def __repr__(self):
        return f"Req_{self.type}{self.args}[{self.index}]\n"
Jason Phang's avatar
Jason Phang committed
458

Leo Gao's avatar
Leo Gao committed
459
460
class RequestFactory:
    """Builds Request objects through attribute access.

    ``rf.loglikelihood(ctx, cont)`` returns ``Request('loglikelihood', (ctx, cont))``.
    """

    def __getattr__(self, attr):
        # Dunder probes (e.g. copy.deepcopy looking for __deepcopy__) must fail
        # normally instead of producing a request builder.
        if attr.startswith("__") and attr.endswith("__"):
            raise AttributeError(attr)

        def fn(*args):
            return Request(attr, args)
        return fn


# Shared factory that tasks use to build requests, e.g. rf.loglikelihood(ctx, cont).
rf = RequestFactory()