openai_completions.py 10.1 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
import os
lintangsutawika's avatar
lintangsutawika committed
2
import time
baberabb's avatar
baberabb committed
3
4
import transformers  # type: ignore
from typing import List, Tuple
Leo Gao's avatar
Leo Gao committed
5
from tqdm import tqdm
lintangsutawika's avatar
lintangsutawika committed
6
from lm_eval import utils
7
8
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
Leo Gao's avatar
Leo Gao committed
9
10


baberabb's avatar
baberabb committed
11
def get_result(response: dict, ctxlen: int) -> Tuple[float, bool]:
    """Process results from OpenAI API response.

    :param response: dict
        A single OpenAI API completion choice, with an ``echo``'d prompt, so
        ``response["logprobs"]`` covers context tokens as well as the continuation.
    :param ctxlen: int
        Length of context (so we can slice them away and only keep the predictions)
    :return:
        continuation_logprobs: float
            Sum of the log probabilities of the continuation tokens.
            (NOTE: previous docs said ``np.array``; the code has always
            returned a scalar sum.)
        is_greedy: bool
            whether argmax matches given continuation exactly
    """
    logprobs = response["logprobs"]
    # Scalar total log-likelihood of the continuation (tokens after the context).
    continuation_logprobs = sum(logprobs["token_logprobs"][ctxlen:])

    # The continuation is "greedy" iff, at every continuation position, the
    # highest-logprob candidate in top_logprobs is the token actually present.
    # all(...) short-circuits on the first mismatch, like the original break.
    is_greedy = all(
        max(top, key=top.get) == token
        for token, top in zip(
            logprobs["tokens"][ctxlen:], logprobs["top_logprobs"][ctxlen:]
        )
    )

    return continuation_logprobs, is_greedy
Jason Phang's avatar
gpt3  
Jason Phang committed
37
38


Leo Gao's avatar
Leo Gao committed
39
def oa_completion(**kwargs):
    """Send a completion request to the OpenAI API.

    Any ``openai.error.OpenAIError`` is printed and the request is retried
    indefinitely, sleeping between attempts with a back-off delay that grows
    by a factor of 1.5 each failure (starting at 3 seconds).
    """
    import openai

    delay = 3
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except openai.error.OpenAIError:
            import traceback

            # Surface the failure for debugging, then wait and grow the
            # back-off before the next attempt.
            traceback.print_exc()
            time.sleep(delay)
            delay = delay * 1.5


haileyschoelkopf's avatar
haileyschoelkopf committed
58
@register_model("openai", "openai-completions", "gooseai")
class OpenaiCompletionsLM(LM):
    """LM implementation backed by the OpenAI Completions API.

    Requests are sent in chunks of REQ_CHUNK_SIZE prompts per API call, and a
    GPT-2 tokenizer is used locally to encode/decode text. Reads the API key
    from the ``OPENAI_API_SECRET_KEY`` environment variable at construction.
    """

    # Number of prompts batched into a single API request.
    REQ_CHUNK_SIZE = 20

    def __init__(
        self,
        engine: str = "text-davinci-003",
        truncate: bool = False,
        batch_size: int = 1,
    ):
        """
        :param engine: str
            OpenAI API engine (e.g. davinci)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        :param batch_size: int
            Accepted for interface compatibility; not used here (the
            ``batch_size`` property below raises NotImplementedError).
        """
        super().__init__()

        import openai

        self.engine = engine
        # Local GPT-2 tokenizer used for encoding prompts; NOTE(review): this
        # presumably matches the engine's tokenization closely enough for
        # truncation arithmetic — confirm for non-GPT-2-family engines.
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")

        self.vocab_size = self.tokenizer.vocab_size

        # to make the annoying "Using pad_token, but it is not set yet." error go away
        self.tokenizer.pad_token = "<|endoftext|>"
        # Sanity check that the expected GPT-2 vocabulary is loaded.
        assert self.tokenizer.encode("hello\n\nhello") == [31373, 198, 198, 31373]
        self.truncate = truncate
        self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(
            ["<|endoftext|>"]
        )[0]

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

    @property
    def eot_token_id(self):
        # End-of-text token id from the local tokenizer.
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
        return 2048

    @property
    def max_gen_toks(self):
        # Maximum number of tokens to generate per greedy_until request.
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def tok_encode(self, string: str) -> List[int]:
        """Encode a string to token ids (no BOS/EOS special tokens added)."""
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens: List[int]) -> str:
        """Decode a list of token ids back to a string."""
        return self.tokenizer.decode(tokens)

    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
        """Encode (context, continuation) so the pair tokenizes consistently.

        Trailing whitespace on the context is moved onto the continuation
        before encoding, so the context/continuation split happens on a clean
        token boundary of the jointly-encoded string.
        """
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]
        # Encode the whole string once, then take the continuation as the
        # suffix beyond the context's own encoding length.
        whole_enc = self.tok_encode(context + continuation)
        context_enc = self.tok_encode(context)
        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]
        return context_enc, continuation_enc

    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
        """Tokenize (context, continuation) pairs and score them.

        Returns one (logprob_sum, is_greedy) tuple per request, in the
        original request order.
        """
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # end of text as context
                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
                    continuation
                )
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)

            # Keep the raw strings as a cache key alongside the token ids.
            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def _loglikelihood_tokens(
        self, requests, disable_tqdm=False
    ) -> List[Tuple[float, bool]]:
        """Score tokenized requests via echo'd, 0-new-token completions.

        Each request is (cache_key_or_None, context_enc, continuation_enc).
        Requests are reordered by descending length for batching, then mapped
        back to their original order before returning.
        """
        res = []

        def _collate(x):
            # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
            # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
            # we care about, and so we need some kind of backup for when it isn't
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(
            list(utils.chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)),
            disable=disable_tqdm,
        ):
            inps = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # max_length+1 because the API takes up to 2049 tokens, including the first context token
                inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
                # TODO: the logic is much simpler if we just look at the length of continuation tokens
                # If the pair was truncated on the left, reduce ctxlen by the
                # number of context tokens that were dropped.
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
                )

                inps.append(inp)
                ctxlens.append(ctxlen)

            # echo=True + max_tokens=0 makes the API return logprobs for the
            # prompt tokens themselves without generating anything new.
            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                echo=True,
                max_tokens=0,
                temperature=0.0,
                logprobs=10,
            )

            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
                response.choices, ctxlens, chunk
            ):
                answer = get_result(resp, ctxlen)

                res.append(answer)

                # partial caching
                # cache_key is None for rolling-loglikelihood windows, which
                # must not be cached individually.
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
        return re_ord.get_original(res)

    def greedy_until(self, requests) -> List[str]:
        """Greedily generate until a stop sequence, one string per request."""
        if not requests:
            return []
        res = []
        requests = [req.args for req in requests]

        def _collate(x):
            # Sort key: prompt length in tokens, then the prompt itself.
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        def sameuntil_chunks(xs, size):
            # Group consecutive requests that share the same `until` args so
            # one API call can use a single stop sequence for the whole chunk.
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil

        # todo: more intelligent batching for heterogeneous `until`
        for chunk, request_args in tqdm(
            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
        ):
            inps = []
            for context, _ in chunk:
                context_enc = self.tok_encode(context)
                # Keep only the rightmost tokens, leaving room for generation.
                inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                inps.append(inp)

            # Only the first stop sequence is sent to the API; the rest are
            # applied client-side below.
            try:
                until = request_args["until"][
                    0
                ]  # TODO: does this handle a list of stop seqs correctly?
            except KeyError:
                until = "<|endoftext|>"

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                max_tokens=self.max_gen_toks,
                temperature=0.0,
                logprobs=10,
                stop=until,
            )

            for resp, (context, args_) in zip(response.choices, chunk):
                s = resp["text"]

                until_ = args_.get("until", [])

                # Truncate at the earliest occurrence of each stop sequence.
                for term in until_:
                    if len(term) > 0:
                        s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial(
                    "greedy_until", (context, {"until": until_}), s
                )

                res.append(s)
        return re_ord.get_original(res)

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override greedy_until
        raise NotImplementedError()

    def loglikelihood_rolling(self, requests) -> List[float]:
        """Full-string log-likelihood via rolling windows over each string."""
        loglikelihoods = []

        for (string,) in tqdm([req.args for req in requests]):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            # Prepend a None cache_key so _loglikelihood_tokens skips caching.
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            string_nll = self._loglikelihood_tokens(
                rolling_token_windows,
                disable_tqdm=True,
            )

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)
        return loglikelihoods