base.py 10.8 KB
Newer Older
Leo Gao's avatar
Leo Gao committed
1
2
import abc
import random
thefazzer's avatar
thefazzer committed
3
import numpy as np
&'s avatar
& committed
4
5

from lm_eval.metrics import mean
Jason Phang's avatar
gpt3  
Jason Phang committed
6

Jason Phang's avatar
Jason Phang committed
7

Leo Gao's avatar
Leo Gao committed
8
class LM(abc.ABC):
    """Abstract interface implemented by every language-model backend.

    Subclasses must provide `loglikelihood` and `greedy_until`. Both receive a
    list of requests; see the method docstrings for the expected shapes.
    """

    def __init__(self):
        # Start with a hook that is bound to no cache; a caching wrapper can
        # install a real one later through set_cache_hook().
        self.cache_hook = CacheHook(None)

    @abc.abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
            logprob: float
                The log probability of `continuation`
            isgreedy:
                Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass

    # TODO: Add an optional max length
    @abc.abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

        :param requests: list
            A list of pairs (context, until)
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of strings continuation
            continuation: str
                The generated continuation.
        """
        pass

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Constructor method, in case models need additional arguments
        e.g. OpenAI API engine, paths for loading, other params

        The base implementation ignores `arg_string` and calls `cls()` with no
        arguments; subclasses override this to parse it.

        :param arg_string: str
            Left up to individual model class to handle
        :return: an instance of `cls`
        """
        return cls()

    def set_cache_hook(self, cache_hook):
        """Replace the cache hook stored on this model.

        :param cache_hook:
            The hook object to store as `self.cache_hook`.
        """
        self.cache_hook = cache_hook

Leo Gao's avatar
Leo Gao committed
69

70
class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary e.g.
        {"question": ..., "answer": ...} or
        a tuple (question, answer)
    """

    def __init__(self):
        self.download()
        # Lazily-filled caches so the doc iterables are materialised only once.
        self._training_docs = None
        self._fewshot_docs = None

    def download(self):
        """Downloads the task dataset if necessary"""
        pass

    @abc.abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abc.abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abc.abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self):
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def validation_docs(self):
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def test_docs(self):
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return []

    def fewshot_examples(self, k, rnd):
        """Sample `k` training docs without replacement.

        :param k: int
            Number of examples to draw.
        :param rnd:
            Object exposing `sample(population, k)`, e.g. a `random.Random`.
        :return: list
            The sampled training docs.
        """
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    @abc.abstractmethod
    def doc_to_text(self, doc):
        """Return the prompt text for `doc` (string-concatenated in fewshot_context)."""
        pass

    @abc.abstractmethod
    def doc_to_target(self, doc):
        """Return the target text for `doc`, appended to its prompt in few-shot examples."""
        pass

    @abc.abstractmethod
    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        pass

    @abc.abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abc.abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass

    @abc.abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        pass

    def fewshot_description(self):
        """Return the natural-language task description; empty by default."""
        return ""

    def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
        """Build the full prompt for `doc`.

        The result is the concatenation of the optional description block, the
        labelled few-shot examples (each `doc_to_text + doc_to_target`, joined
        by blank lines) and finally `doc_to_text(doc)`.

        :param doc:
            The document being evaluated.
        :param num_fewshot: int
            Number of labelled examples to prepend; 0 prepends none.
        :param provide_description: bool
            Whether to prepend `fewshot_description()` when it is non-empty.
        :param rnd:
            Object exposing `sample(population, k)`, e.g. a `random.Random`.
        :return: str
        """
        raw_description = self.fewshot_description()
        description = (raw_description + "\n===\n\n") if provide_description and raw_description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    # has_validation_docs must be *called*: the bound method
                    # itself is always truthy, which would hide test_docs().
                    self._fewshot_docs = list(
                        self.validation_docs() if self.has_validation_docs() else self.test_docs()
                    )

                # draw one extra so that dropping the evaluated doc still leaves enough
                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = "\n\n".join(
                [self.doc_to_text(example_doc) + self.doc_to_target(example_doc) for example_doc in fewshotex]
            ) + "\n\n"

        example = self.doc_to_text(doc)
        return description + labeled_examples + example


Leo Gao's avatar
Leo Gao committed
213
class MultipleChoiceTask(Task):
    """Task base for multiple-choice problems.

    The methods here read `doc['choices']` (a sequence of answer strings) and
    `doc['gold']` (the index of the correct choice).
    """

    def doc_to_target(self, doc):
        """Return the gold choice text, prefixed with a single space."""
        return " " + doc['choices'][doc['gold']]

    def construct_requests(self, doc, ctx):
        """Create one loglikelihood request per choice.

        Only element [0] of each request is kept, i.e. the first of the two
        values a loglikelihood request returns.

        :param doc:
            Document providing `doc['choices']`.
        :param ctx: str
            The context string each choice is scored against.
        :return: list
            One request per choice, in the order of `doc['choices']`.
        """
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0]
            for choice in doc['choices']
        ]

        return lls

    def process_results(self, doc, results):
        """Score one document from the per-choice results.

        :param doc:
            Document providing `doc['gold']` and `doc['choices']`.
        :param results:
            One numeric score per choice, in the order of `doc['choices']`.
        :return: dict
            "acc": 1.0 if the highest raw score is the gold choice, else 0.0.
            "acc_norm": the same, after dividing each score by the character
            length of its choice.
        """
        gold = doc["gold"]

        # Convert explicitly rather than relying on list / ndarray reflected division.
        scores = np.asarray(results)
        acc = 1. if np.argmax(scores) == gold else 0.
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1. if np.argmax(scores / completion_len) == gold else 0.

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }

    def higher_is_better(self):
        """Both reported metrics improve as they increase."""
        return {
            "acc": True,
            "acc_norm": True,
        }

    def aggregation(self):
        """Both reported metrics are aggregated with `mean`."""
        return {
            "acc": mean,
            "acc_norm": mean,
        }


250
# Known request types mapped to the number of values each one returns.
# None marks a type whose result cannot be indexed or unpacked
# (Request.__iter__ / Request.__getitem__ raise IndexError for it).
req_ret_lens = {
    'loglikelihood': 2,
    'greedy_until': None,
}

Leo Gao's avatar
Leo Gao committed
255
256
257
258
259
import os
import json
import hashlib
from sqlitedict import SqliteDict

Leo Gao's avatar
Leo Gao committed
260
261
262
def hash_args(attr, args):
    """Return the hex SHA-256 digest identifying one (method name, arguments) call.

    :param attr:
        Name of the LM method being called; becomes the first JSON element.
    :param args:
        Iterable of JSON-serialisable arguments, appended after `attr`.
    :return: str
        Hex digest of the UTF-8 encoded JSON list `[attr, *args]`.
    """
    payload = [attr, *args]
    encoded = json.dumps(payload).encode('utf-8')
    digest = hashlib.sha256(encoded)
    return digest.hexdigest()
Leo Gao's avatar
Leo Gao committed
263
264


Leo Gao's avatar
Leo Gao committed
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
class CacheHook:
    """Handle through which a model can write individual results into a cache.

    A hook built from `None` has no backing store, so `add_partial` does nothing.
    """

    def __init__(self, cachinglm):
        """Share the `dbdict` of `cachinglm`, or hold no store when it is None."""
        self.dbdict = None if cachinglm is None else cachinglm.dbdict

    def add_partial(self, attr, req, res):
        """Store `res` under the hash of (`attr`, `req`) when a store is attached.

        :param attr:
            Name of the LM method that produced the result.
        :param req:
            The request arguments, as accepted by `hash_args`.
        :param res:
            The result to store.
        """
        if self.dbdict is not None:
            key = hash_args(attr, req)
            self.dbdict[key] = res


Leo Gao's avatar
Leo Gao committed
280
281
282
283
284
285
286
class CachingLM:
    """Wrap an LM so that request results are cached in a SqliteDict.

    Any attribute not found on the wrapper is treated as the name of a batched
    LM method: the returned function serves cached requests from `dbdict` and
    forwards only the uncached ones to the wrapped LM.
    """

    def __init__(self, lm, cache_db):
        """
        :param lm:
            The model to wrap; must provide `set_cache_hook`.
        :param cache_db: str
            Path of the SqliteDict database file.
        """
        self.lm = lm
        self.cache_db = cache_db

        # dirname is '' for a bare filename, and os.makedirs('') raises
        # FileNotFoundError, so only create the directory when there is one.
        cache_dir = os.path.dirname(cache_db)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        """Return a caching version of the wrapped LM's batched method `attr`."""
        def fn(requests):
            res = []
            # (position in `requests`, cache key, request) for every cache miss
            pending = []

            # figure out which ones are cached and which ones are new
            for position, req in enumerate(requests):
                hsh = hash_args(attr, req)

                if hsh in self.dbdict:
                    res.append(self.dbdict[hsh])
                else:
                    res.append(None)
                    pending.append((position, hsh, req))

            # nothing to compute when every request was a cache hit
            if pending:
                # actually run the LM
                rem_res = getattr(self.lm, attr)([req for _, _, req in pending])

                # Stick the new ones back into the list and also cache them.
                # Positions are tracked explicitly instead of scanning for None
                # placeholders, so a None result cannot clobber a later slot.
                for (position, hsh, _), r in zip(pending, rem_res):
                    res[position] = r
                    self.dbdict[hsh] = r
                self.dbdict.commit()

            return res
        return fn

    def get_cache_hook(self):
        """Return a CacheHook bound to this wrapper's `dbdict`."""
        return CacheHook(self)
Leo Gao's avatar
Leo Gao committed
328

Jason Phang's avatar
Jason Phang committed
329

330
331
332
333
class Request:
    """A pending LM call: a request type, its arguments, and an optional result index.

    Indexing or iterating a request yields new requests that carry the same
    type and arguments plus the chosen `index`.
    """

    def __init__(self, type, args, index=None):
        """
        :param type: str
            Request type; must be a key of `req_ret_lens`.
        :param args:
            Arguments of the call.
        :param index:
            Which returned value this request refers to, or None.
        :raises NotImplementedError: if `type` is not a known request type.
        """
        if type not in req_ret_lens:
            raise NotImplementedError('The request type {} is not implemented!'.format(type))

        self.type = type
        self.args = args
        self.index = index

    def __iter__(self):
        """Yield one indexed Request per value this request type returns.

        :raises IndexError: if the request type has no fixed number of values.
        """
        ret_len = req_ret_lens[self.type]
        if ret_len is None:
            raise IndexError('This request type does not return multiple arguments!')
        for i in range(ret_len):
            yield Request(self.type, self.args, i)

    def __getitem__(self, i):
        """Return a copy of this request that refers to returned value `i`.

        :raises IndexError: if the request type has no fixed number of values.
        """
        if req_ret_lens[self.type] is None:
            raise IndexError('This request type does not return multiple arguments!')
        return Request(self.type, self.args, i)

    # Defining __eq__ without __hash__ leaves instances unhashable.
    def __eq__(self, other):
        """Requests are equal when type, args and index all match."""
        if not isinstance(other, Request):
            # Let Python try the reflected comparison instead of raising
            # AttributeError on objects that lack type/args/index.
            return NotImplemented
        return self.type == other.type and self.args == other.args and self.index == other.index

    def __repr__(self):
        return f"Req_{self.type}{self.args}[{self.index}]\n"
Jason Phang's avatar
Jason Phang committed
356

Leo Gao's avatar
Leo Gao committed
357
358
class RequestFactory:
    """Build Request objects through attribute access.

    `factory.<name>(*args)` returns `Request('<name>', args)`; `args` is the
    tuple of positional arguments. Unknown names are rejected by
    `Request.__init__`, not here.
    """

    def __getattr__(self, attr):
        """Return a function that wraps its positional arguments in a Request of type `attr`."""
        def fn(*args):
            return Request(attr, args)
        return fn


# Module-level factory instance, used in this file as rf.loglikelihood(ctx, continuation).
rf = RequestFactory()