base.py 13.8 KB
Newer Older
Leo Gao's avatar
Leo Gao committed
1
2
import abc
import random
thefazzer's avatar
thefazzer committed
3
import numpy as np
&'s avatar
& committed
4
5

from lm_eval.metrics import mean
Jason Phang's avatar
gpt3  
Jason Phang committed
6

Jason Phang's avatar
Jason Phang committed
7

Leo Gao's avatar
Leo Gao committed
8
class LM(abc.ABC):
    """Abstract interface for a language model backend.

    Subclasses implement the three request types (`loglikelihood`,
    `loglikelihood_perplexity`, `greedy_until`); each takes a list of requests
    and returns a list of results.
    """

    def __init__(self):
        # Start with a hook that has no backing store, so add_partial() is a
        # no-op until set_cache_hook() installs a real one.
        self.cache_hook = CacheHook(None)

    @abc.abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass

    @abc.abstractmethod
    def loglikelihood_perplexity(self, requests):
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
          the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still be a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list
            A list of strings
            string: str
                String for which we are computing per-token loglikelihood
        :return: list
            One result per input string.
            NOTE(review): the exact element type is not pinned down here;
            `req_ret_lens` declares a single return value per request for this
            request type -- confirm against concrete implementations.
        """
        pass

    # TODO: Add an optional max length
    @abc.abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of strings continuation
            continuation: str
                The generated continuation.
        """
        pass

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Constructor method, in case models need additional arguments
        e.g. OpenAI API engine, paths for loading, other params

        The base implementation ignores `arg_string` and calls `cls()`.

        :param arg_string: str
            Left up to individual model class to handle
        :return: an instance of `cls`
        """
        return cls()

    def set_cache_hook(self, cache_hook):
        """Replace the cache hook stored on this model.

        :param cache_hook:
            The hook object to store in `self.cache_hook`.
        """
        self.cache_hook = cache_hook

Leo Gao's avatar
Leo Gao committed
111

112
class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary e.g.
        {"question": ..., "answer": ...}
    or a tuple such as
        (question, answer)
    """

    def __init__(self):
        self.download()
        # Lazily-filled caches used when sampling few-shot examples.
        self._training_docs = None
        self._fewshot_docs = None

    def download(self):
        """Downloads the task dataset if necessary"""
        pass

    @abc.abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abc.abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abc.abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self):
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def test_docs(self):
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def fewshot_examples(self, k, rnd):
        """Sample `k` few-shot examples from the training docs.

        The training docs are materialized into a list on first use and reused
        on later calls.

        :param k: int
            Number of examples to draw.
        :param rnd:
            Random source providing `sample(population, k)`.
        :return: list
            `k` docs drawn without replacement from the training docs.
        """
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    @abc.abstractmethod
    def doc_to_text(self, doc):
        """Render the prompt (question) part of `doc`; used to build contexts."""
        pass

    @abc.abstractmethod
    def doc_to_target(self, doc):
        """Render the target (answer) part of `doc`; appended to the text of few-shot examples."""
        pass

    @abc.abstractmethod
    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        pass

    @abc.abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abc.abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass

    @abc.abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        pass

    def fewshot_description(self):
        """Natural-language description placed before the examples; empty by default."""
        return ""

    def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
        """Build the full context string for evaluating `doc`.

        The result is: optional description, then `num_fewshot` labeled
        examples separated by blank lines, then the text of `doc` itself.

        :param doc:
            The document being evaluated.
        :param num_fewshot: int
            Number of labeled examples to include.
        :param provide_description: bool
            Whether to prepend `fewshot_description()` (skipped when it is empty).
        :param rnd:
            Random source providing `sample(population, k)`.
        :return: str
            The assembled context.
        """
        raw_description = self.fewshot_description()
        description = (raw_description + "\n===\n\n") if provide_description and raw_description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    # has_validation_docs must be *called*: the bound method
                    # object itself is always truthy, which would make the
                    # test_docs() fallback unreachable.
                    self._fewshot_docs = list(
                        self.validation_docs() if self.has_validation_docs() else self.test_docs()
                    )

                # Draw one extra so that dropping the evaluated doc still
                # leaves num_fewshot examples.
                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = "\n\n".join(
                [self.doc_to_text(ex) + self.doc_to_target(ex) for ex in fewshotex]
            ) + "\n\n"

        example = self.doc_to_text(doc)
        return description + labeled_examples + example


Leo Gao's avatar
Leo Gao committed
255
class MultipleChoiceTask(Task):
    """Task whose docs carry a list of answer choices and the index of the correct one.

    Docs are indexed as `doc['choices']` (the candidate answers) and
    `doc['gold']` (position of the correct choice in that list).
    """

    def doc_to_target(self, doc):
        """Return the gold choice, prefixed with a single space."""
        return " " + doc['choices'][doc['gold']]

    def construct_requests(self, doc, ctx):
        """Build one loglikelihood request per choice.

        Only element 0 of each loglikelihood request is kept, so `results` in
        process_results holds one value per choice, in choice order.
        """
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0]
            for choice in doc['choices']
        ]

        return lls

    def process_results(self, doc, results):
        """Score one doc.

        :param doc:
            The document; `doc["gold"]` and `doc["choices"]` are read.
        :param results:
            One score per choice, in the order produced by construct_requests.
        :return: dict
            "acc": 1.0 if the highest-scoring choice is the gold one, else 0.0.
            "acc_norm": same, after dividing each score by `len(choice)`.
        """
        gold = doc["gold"]

        acc = 1. if np.argmax(results) == gold else 0.
        # Per-choice length used to normalize scores (len() of each choice).
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1. if np.argmax(results / completion_len) == gold else 0.

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }

    def higher_is_better(self):
        """Both accuracy metrics improve as they increase."""
        return {
            "acc": True,
            "acc_norm": True,
        }

    def aggregation(self):
        """Both metrics are aggregated over docs with `mean`."""
        return {
            "acc": mean,
            "acc_norm": mean,
        }


Jason Phang's avatar
Jason Phang committed
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
class PerplexityTask(Task, abc.ABC):
    """Base class for zero-shot perplexity tasks.

    Docs are passed to the LM as-is through a `loglikelihood_perplexity`
    request; few-shot examples and descriptions are not supported.
    """

    def has_training_docs(self):
        return False

    def fewshot_description(self):
        return ""

    def fewshot_examples(self, k, rnd):
        """Only zero-shot is supported; always returns an empty list."""
        assert k == 0
        return []

    def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
        """Return the empty context; asserts zero-shot with no description."""
        assert num_fewshot == 0
        assert not provide_description
        return ""

    def higher_is_better(self):
        """
        :returns: {str: bool}
            Keyed by submetric name, matching the keys used by
            process_results and aggregation. Lower perplexity is better.
        """
        return {
            "perplexity": False,
        }

    def doc_to_text(self, doc):
        """The doc itself is the text."""
        return doc

    def doc_to_target(self, doc):
        """Not supported for perplexity tasks.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError()

    def construct_requests(self, doc, ctx):
        """Build a single `loglikelihood_perplexity` request for `doc`.

        :param ctx:
            Must be empty (asserted).
        """
        assert not ctx
        req = rf.loglikelihood_perplexity(doc)
        return req

    def process_results(self, doc, results):
        """Unpack the single result and report it under the "perplexity" key.

        The per-doc value is the raw loglikelihood result; conversion to a
        perplexity happens in aggregation.
        """
        loglikelihood, = results
        return {
            "perplexity": loglikelihood,
        }

    def aggregation(self):
        return {
            "perplexity": self.compute_perplexity_from_loglikelihood,
        }

    @classmethod
    def compute_perplexity_from_loglikelihood(cls, loglikelihoods):
        """Concatenate the per-doc results and return exp(-mean) as a float.

        :param loglikelihoods:
            Sequence of array-likes accepted by `np.concatenate`
            (presumably per-token log-probabilities for each doc -- confirm
            against LM implementations).
        """
        aggregate_logprobs = np.concatenate(loglikelihoods)
        perplexity = np.exp(-aggregate_logprobs.mean())
        return float(perplexity)


342
# Number of values a single request of each type returns.
# `None` marks request types whose result cannot be indexed or iterated
# (Request.__getitem__ / Request.__iter__ raise IndexError for them).
# Request.__init__ rejects any type that is not a key of this table.
req_ret_lens = {
    'loglikelihood': 2,  # (logprob, isgreedy)
    'greedy_until': None,
    'loglikelihood_perplexity': 1,
}

Leo Gao's avatar
Leo Gao committed
348
349
350
351
352
import os
import json
import hashlib
from sqlitedict import SqliteDict

Leo Gao's avatar
Leo Gao committed
353
354
355
def hash_args(attr, args):
    """Return the SHA-256 hex digest identifying a request.

    The key is the JSON encoding of a flat list holding `attr` followed by
    every element of `args`, encoded as UTF-8.
    """
    payload = json.dumps([attr, *args])
    digest = hashlib.sha256(payload.encode('utf-8'))
    return digest.hexdigest()
Leo Gao's avatar
Leo Gao committed
356
357


Leo Gao's avatar
Leo Gao committed
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
class CacheHook:
    """Write-only handle onto a caching LM's key/value store.

    Constructed with `None`, the hook has no store and add_partial() does
    nothing.
    """

    def __init__(self, cachinglm):
        # Share the caching LM's store when one is supplied.
        self.dbdict = None if cachinglm is None else cachinglm.dbdict

    def add_partial(self, attr, req, res):
        """Store `res` under the hash of (`attr`, `req`), if a store is attached."""
        if self.dbdict is not None:
            key = hash_args(attr, req)
            self.dbdict[key] = res


Leo Gao's avatar
Leo Gao committed
373
374
375
376
377
378
379
class CachingLM:
    """Wrap an LM so that per-request results are cached in a SqliteDict.

    Any attribute not found on the wrapper is treated as an LM request type:
    accessing it returns a function that takes a list of requests, serves
    cached ones from the store, forwards the rest to the wrapped LM, and
    caches the new results.
    """

    def __init__(self, lm, cache_db):
        """
        :param lm:
            The wrapped LM; must provide `set_cache_hook`.
        :param cache_db: str
            Path of the SqliteDict database file.
        """
        self.lm = lm
        self.cache_db = cache_db
        # A bare filename has an empty dirname, and os.makedirs("") raises;
        # only create the parent directory when there is one.
        cache_dir = os.path.dirname(cache_db)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        def fn(requests):
            res = []
            remaining_reqs = []
            # Positions in `res` that still need a value from the LM. Tracking
            # them explicitly keeps results aligned even if the LM returns
            # None for some request.
            missing_positions = []

            # figure out which ones are cached and which ones are new
            for position, req in enumerate(requests):
                hsh = hash_args(attr, req)
                if hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)
                    missing_positions.append(position)

            # actually run the LM
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            for position, req, r in zip(missing_positions, remaining_reqs, rem_res):
                res[position] = r

                # caching
                hsh = hash_args(attr, req)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res
        return fn

    def get_cache_hook(self):
        """Return a CacheHook bound to this wrapper's store."""
        return CacheHook(self)
Leo Gao's avatar
Leo Gao committed
421

Jason Phang's avatar
Jason Phang committed
422

423
424
425
426
class Request:
    """A single LM request: a request type, its arguments, and an optional
    index selecting one element of the result.
    """

    def __init__(self, type, args, index=None):
        """
        :param type: str
            Request type; must be a key of `req_ret_lens`.
        :param args:
            Arguments for the request.
        :param index:
            Which element of the result this request refers to, or None.
        :raises NotImplementedError: if `type` is not a known request type.
        """
        if type not in req_ret_lens:
            raise NotImplementedError('The request type {} is not implemented!'.format(type))

        self.type = type
        self.args = args
        self.index = index

    def __iter__(self):
        """Yield one indexed Request per return value of this request type.

        :raises IndexError: if the request type has no fixed return length.
        """
        if req_ret_lens[self.type] is None:
            raise IndexError('This request type does not return multiple arguments!')
        for i in range(req_ret_lens[self.type]):
            yield Request(self.type, self.args, i)

    def __getitem__(self, i):
        """Return a copy of this request that selects element `i` of the result.

        :raises IndexError: if the request type has no fixed return length.
        """
        if req_ret_lens[self.type] is None:
            raise IndexError('This request type does not return multiple arguments!')
        return Request(self.type, self.args, i)

    def __eq__(self, other):
        # Defer to the other operand instead of raising AttributeError when
        # compared against something that is not a Request.
        if not isinstance(other, Request):
            return NotImplemented
        return self.type == other.type and self.args == other.args and self.index == other.index

    def __repr__(self):
        return f"Req_{self.type}{self.args}[{self.index}]\n"
Jason Phang's avatar
Jason Phang committed
449

Leo Gao's avatar
Leo Gao committed
450
451
class RequestFactory:
    """Builds Request objects via attribute access.

    `rf.<request_type>(*args)` is equivalent to `Request("<request_type>", args)`.
    """

    def __getattr__(self, attr):
        # Protocol probes (copy, pickle, hasattr checks, ...) look up dunder
        # names; answering those with a request builder would be misleading.
        if attr.startswith("__") and attr.endswith("__"):
            raise AttributeError(attr)

        def fn(*args):
            return Request(attr, args)
        return fn


# Shared factory instance used by tasks to construct requests.
rf = RequestFactory()