gpt3.py 5.7 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
2
import os
import transformers
Jason Phang's avatar
lib  
Jason Phang committed
3
4
from lm_eval.base import LM
from lm_eval import utils
Leo Gao's avatar
Leo Gao committed
5
from tqdm import tqdm
Leo Gao's avatar
Leo Gao committed
6
import time
Leo Gao's avatar
Leo Gao committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22


def get_result(response, ctxlen):
    """Extract the continuation log-likelihood and greedy flag from one choice.

    :param response: dict
        A single completion choice containing ``logprobs`` with ``tokens``,
        ``token_logprobs`` and ``top_logprobs`` (as returned with echo=True).
    :param ctxlen: int
        Number of leading tokens that belong to the context; only tokens from
        this index onward count toward the continuation.
    :return: (float, bool)
        Sum of the continuation tokens' logprobs, and whether every
        continuation token was also the model's top-ranked token.
    """
    token_logprobs = response["logprobs"]["token_logprobs"]
    continuation_logprobs = sum(token_logprobs[ctxlen:])

    tokens = response["logprobs"]["tokens"]
    top_logprobs = response["logprobs"]["top_logprobs"]

    is_greedy = True
    for actual, candidates in zip(tokens[ctxlen:], top_logprobs[ctxlen:]):
        # The generation was greedy only if each emitted token is the
        # highest-probability candidate at its position.
        best = max(candidates, key=candidates.get)
        if best != actual:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy
Jason Phang's avatar
gpt3  
Jason Phang committed
23
24


Leo Gao's avatar
Leo Gao committed
25
26
27
28
29
30
31
32
33
34
35
36
def oa_completion(**kwargs):
    """Call the OpenAI Completion endpoint, retrying on API errors.

    Retries indefinitely with geometric backoff (3s initial delay, x1.5
    growth) whenever the API raises ``openai.error.OpenAIError``.
    All keyword arguments are forwarded to ``openai.Completion.create``.
    """
    import openai

    delay = 3
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except openai.error.OpenAIError:
            # Transient API failure: wait, grow the delay, and retry.
            time.sleep(delay)
            delay *= 1.5


Jason Phang's avatar
gpt3  
Jason Phang committed
37
class GPT3LM(LM):
    # Maximum number of tokens (context + continuation) accepted per request;
    # sequences are truncated from the left to fit this window.
    MAX_LENGTH = 2048
    # Number of requests batched into a single API call.
    REQ_CHUNK_SIZE = 20
    # Maximum number of tokens to generate per request in greedy_until.
    MAX_GEN_TOKS = 256

    def __init__(self, engine, truncate=False):
        """

        :param engine: str
            OpenAI API engine (e.g. davinci)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        """
        super().__init__()
        import openai

        self.engine = engine
        self.truncate = truncate

        # Tokenize locally with the GPT-2 fast tokenizer rather than via the API.
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
        # to make the annoying "Using pad_token, but it is not set yet." error go away
        self.tokenizer.pad_token = "<|endoftext|>"
        # Sanity-check that the tokenizer produces the expected GPT-2 encoding.
        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Build a GPT3LM from a comma-separated arg string, e.g. "engine=davinci".

        Unspecified engine defaults to "davinci".
        """
        parsed = utils.simple_parse_args_string(arg_string)
        engine = parsed.get("engine", "davinci")
        return cls(engine=engine)

Leo Gao's avatar
Leo Gao committed
70
    def loglikelihood(self, requests):
        """Tokenize (context, continuation) pairs and score them.

        An empty context is replaced by the end-of-text token (50256) so the
        continuation is still conditioned on something.
        """
        tokenized = []
        for context, continuation in requests:
            # Empty context -> condition on end-of-text instead.
            context_enc = self.tokenizer.encode(context) if context else [50256]
            continuation_enc = self.tokenizer.encode(continuation)

            tokenized.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(tokenized)

    def _loglikelihood_tokens(self, requests):
        """Score pre-tokenized requests against the OpenAI API.

        :param requests: list of (cache_key, context_enc, continuation_enc)
            tuples; cache_key may be None to skip partial caching.
        :return: list of (logprob_sum, is_greedy) tuples in the original
            request order.
        """
        import openai  # kept for parity with the original import side effect

        results = []

        def _collate(req):
            # Sort by token length (then content) so similarly-sized requests
            # share a batch.  This doesn't efficiently handle last-token
            # differences yet, but those are kinda annoying because it's not
            # guaranteed that the 100 or so logprobs we get to see actually
            # contain all the continuations we care about and so we need some
            # kind of backup for when it isn't.
            toks = req[1] + req[2]
            return (len(toks), tuple(toks))

        reord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
            prompts = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # Keep only the last MAX_LENGTH tokens; if the context was
                # truncated, shrink ctxlen by the number of tokens dropped.
                full_enc = context_enc + continuation_enc
                prompts.append(full_enc[-self.MAX_LENGTH:])
                overflow = max(0, len(full_enc) - self.MAX_LENGTH)
                ctxlens.append(len(context_enc) - overflow)

            # max_tokens=0 + echo=True returns logprobs for the prompt itself
            # without generating anything.
            response = oa_completion(
                engine=self.engine,
                prompt=prompts,
                echo=True,
                max_tokens=0,
                temperature=0.,
                logprobs=10,
            )

            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(response.choices, ctxlens, chunk):
                answer = get_result(resp, ctxlen)

                results.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

        return reord.get_original(results)
Leo Gao's avatar
Leo Gao committed
126
127

    def greedy_until(self, requests):
        """Greedily generate a completion for each (context, until) request.

        :param requests: list of (context, until) pairs, where ``until`` is a
            list of stop strings.
        :return: list of generated strings, each truncated at the first stop
            sequence, in the same order as ``requests``.
        """
        if not requests:
            return []

        res = []

        def _collate(x):
            toks = self.tokenizer.encode(x[0])
            return (len(toks), x[0])

        reord = utils.Reorderer(requests, _collate)

        def sameuntil_chunks(xs, size):
            # Group consecutive requests sharing the same `until` stop
            # sequences, capping each group at `size`, so every API call can
            # pass a single `stop` argument.
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil

        # todo: more intelligent batching for heterogenous `until`
        for chunk, until in tqdm(list(sameuntil_chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
            inps = []
            for context, _ in chunk:
                context_enc = self.tokenizer.encode(context)
                # Reserve room in the window for MAX_GEN_TOKS generated tokens.
                inp = context_enc[-(self.MAX_LENGTH - self.MAX_GEN_TOKS):]
                inps.append(inp)

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                max_tokens=self.MAX_GEN_TOKS,
                temperature=0.,
                logprobs=10,
                stop=until,
            )

            for resp, (context, until) in zip(response.choices, chunk):
                s = resp['text']

                # The API stop only covers up to 4 sequences and may leave the
                # stop text in place; trim at the first occurrence of any term.
                for term in until:
                    s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial("greedy_until", (context, until), s)

                res.append(s)

        # Bug fix: the original ended with `reord.get_original(res)()`, which
        # called the returned list and raised TypeError on every invocation.
        return reord.get_original(res)