import os
import time

import transformers
from tqdm import tqdm

from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model


def get_result(response, ctxlen):
    """Process results from OpenAI API response.

    :param response: dict
        OpenAI API Response
    :param ctxlen: int
        Length of context (so we can slice them away and only keep the predictions)
    :return:
        continuation_logprobs: np.array
            Log probabilities of continuation tokens
        is_greedy: bool
            whether argmax matches given continuation exactly
    """
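    # With echo=True the API returns logprobs for every prompt token; summing the
    # slice after ctxlen scores just the continuation, and comparing each token to
    # the argmax of its top_logprobs tells us whether greedy decoding matches.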
    is_greedy = True
    logprobs = response["logprobs"]["token_logprobs"]
    continuation_logprobs = sum(logprobs[ctxlen:])

    for i in range(ctxlen, len(response["logprobs"]["tokens"])):
        token = response["logprobs"]["tokens"][i]
        top_tokens = response["logprobs"]["top_logprobs"][i]
        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
        if top_token != token:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy


def oa_completion(**kwargs):
    """Query OpenAI API for completion.

    Retries with exponential back-off until the API responds.
    """
    import openai

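    # exponential back-off: start at 3 seconds and grow 1.5x after each failure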
    backoff_time = 3
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except openai.error.OpenAIError:
            import traceback

            traceback.print_exc()
            time.sleep(backoff_time)
            backoff_time *= 1.5


@register_model("openai", "gooseai")
class GPT3LM(LM):
    REQ_CHUNK_SIZE = 20

    def __init__(self, engine, truncate=False):
        """

        :param engine: str
            OpenAI API engine (e.g. davinci)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        """
        super().__init__()

        import openai

        self.engine = engine
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")

        self.vocab_size = self.tokenizer.vocab_size

        # set pad_token to silence the "Using pad_token, but it is not set yet." warning
        self.tokenizer.pad_token = "<|endoftext|>"
        assert self.tokenizer.encode("hello\n\nhello") == [31373, 198, 198, 31373]
        self.truncate = truncate
        self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(
            ["<|endoftext|>"]
        )[0]

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
        return 2048

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        res = []

        def _collate(x):
            # this doesn't efficiently handle last-token differences yet; those are
            # awkward because the ~100 top logprobs the API returns aren't guaranteed
            # to contain every continuation we care about, so we'd need some kind of
            # fallback for when they don't
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(
            list(utils.chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)),
            disable=disable_tqdm,
        ):
            inps = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # max_length+1 because the API takes up to 2049 tokens, including the first context token
                inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
                # TODO: the logic is much simpler if we just look at the length of continuation tokens
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
                )

                inps.append(inp)
                ctxlens.append(ctxlen)

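            # one scoring call per chunk: echo=True with max_tokens=0 returns
            # logprobs for the prompt tokens themselves, without generating anything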
            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                echo=True,
                max_tokens=0,
                temperature=0.0,
                logprobs=10,
            )

            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
                response.choices, ctxlens, chunk
            ):
                answer = get_result(resp, ctxlen)

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

        return re_ord.get_original(res)

    def greedy_until(self, requests):
        if not requests:
            return []
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

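        # group consecutive requests that share the same `until` stop sequence,
        # so each group can be served by a single API call with one `stop` value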
        def sameuntil_chunks(xs, size):
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil

        # TODO: more intelligent batching for heterogeneous `until`
        for chunk, until in tqdm(
            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
        ):
            inps = []
            for context, _ in chunk:
                context_enc = self.tok_encode(context)
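                # truncate context from the left, reserving room for max_gen_toks
                # generated tokens within the model's max_length context window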
                inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                inps.append(inp)

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                max_tokens=self.max_gen_toks,
                temperature=0.0,
                logprobs=10,
                stop=until,
            )

            for resp, (context, until_) in zip(response.choices, chunk):
                s = resp["text"]

                for term in until_:
                    s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial("greedy_until", (context, until_), s)

                res.append(s)

        return re_ord.get_original(res)

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override greedy_until
        raise NotImplementedError()
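

if __name__ == "__main__":
    # Minimal smoke-test sketch, added for illustration and not part of the
    # harness API: it assumes OPENAI_API_SECRET_KEY is set, the legacy pre-1.0
    # `openai` SDK is installed, and it performs a live, billable API call.
    # The engine name "davinci" is illustrative only.
    lm = GPT3LM(engine="davinci")
    # `greedy_until` takes (context, until) pairs, where `until` is a list of
    # stop strings, and returns one generated continuation per context.
    print(lm.greedy_until([("The quick brown fox", ["\n"])]))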