# cost_estimate.py: dry-run estimate of token usage and API cost for lm_eval tasks.
import random
import transformers
from lm_eval import tasks, evaluator
from lm_eval.base import LM


class DryrunLM(LM):
    """Stand-in model that tallies token usage instead of running inference.

    Every request method adds the GPT-2 token count of its input text to
    ``self.tokencost`` and returns one placeholder result per request, so an
    evaluation can be run end to end purely to measure how many tokens it
    would consume.
    """

    # Worst-case number of generated tokens charged per ``greedy_until`` request.
    MAX_GEN_TOKENS = 256
    # Worst-case extra context charged per ``loglikelihood_rolling`` request.
    MAX_CONTEXT_TOKENS = 2048

    def __init__(self):
        # Running total of tokens seen; callers reset it to 0 between tasks.
        self.tokencost = 0
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
        self.tokenizer.pad_token = "<|endoftext|>"

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Build an instance; ``arg_string`` is accepted but ignored."""
        return cls()

    def _count_tokens(self, text):
        """Return the number of GPT-2 tokens in ``text``."""
        return len(self.tokenizer.tokenize(text))

    def loglikelihood(self, requests):
        """Charge ``ctx + cont`` for each request and return dummy scores.

        Returns one ``(negative random float, False)`` pair per request.
        """
        res = []

        for ctx, cont in requests:
            res.append((-random.random(), False))
            self.tokencost += self._count_tokens(ctx + cont)

        return res

    def greedy_until(self, requests):
        """Charge the context plus a worst-case generation for each request.

        Returns the placeholder string ``"lol"`` once per request.
        """
        res = []

        for ctx, _until in requests:
            res.append("lol")

            # assume worst case - generates until MAX_GEN_TOKENS
            self.tokencost += self._count_tokens(ctx) + self.MAX_GEN_TOKENS

        return res

    def loglikelihood_rolling(self, requests):
        """Charge each string plus a worst-case extra context window.

        Returns one negative random float per request. Previously this
        returned an empty list regardless of how many requests were given.
        NOTE(review): a bare float per request is assumed to be what the
        evaluator expects from this method - confirm against lm_eval.base.LM.
        """
        res = []

        for s, in requests:
            res.append(-random.random())

            # assume worst case: extra full context
            self.tokencost += self._count_tokens(s) + self.MAX_CONTEXT_TOKENS

        return res


# Price per 1000 tokens for each engine, in table-column order.
# NOTE(review): presumably USD API pricing at the time of writing - confirm.
_PRICE_PER_1K_TOKENS = {
    "Ada": 0.0008,
    "Babbage": 0.0012,
    "Curie": 0.006,
    "Davinci": 0.06,
}


def _cost_row(label, tokens):
    """Return a table row: label, token count, then the cost for each engine."""
    return [label, tokens] + [
        tokens / 1000 * rate for rate in _PRICE_PER_1K_TOKENS.values()
    ]


def main():
    """Dry-run every listed task and print a Markdown table of token costs.

    Each task is evaluated against a ``DryrunLM`` with its token counter reset
    first, so the count printed for a task covers that task alone. The table
    is sorted by token count, largest first, and ends with a total row.
    """
    lm = DryrunLM()

    task_list = "arc_challenge,arc_easy,boolq,cola,copa,headqa,hellaswag,lambada,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,record,rte,sciq,sst,triviaqa,webqs,wic,wikitext,winogrande,wnli,wsc"
    values = []
    for taskname in task_list.split(","):
        lm.tokencost = 0
        evaluator.evaluate(lm, {taskname: tasks.get_task(taskname)()}, False, 0, None, bootstrap_iters=10)

        print(taskname, lm.tokencost)
        values.append(_cost_row(taskname, lm.tokencost))

    # Imported here so pytablewriter is only required when the table is built.
    from pytablewriter import MarkdownTableWriter

    writer = MarkdownTableWriter()
    writer.headers = ["Task", "Tokens"] + list(_PRICE_PER_1K_TOKENS)

    # Most expensive task first; ties keep their task_list order.
    values.sort(key=lambda row: row[1], reverse=True)
    totcost = sum(row[1] for row in values)
    values.append(_cost_row("**Total**", totcost))

    writer.value_matrix = values

    print(writer.dumps())
# Run the cost estimate only when this file is executed as a script.
if __name__ == "__main__":
    main()