openai_completions.py 13.5 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
import os
lintangsutawika's avatar
lintangsutawika committed
2
import time
baberabb's avatar
baberabb committed
3
from typing import List, Tuple
Leo Gao's avatar
Leo Gao committed
4
from tqdm import tqdm
lintangsutawika's avatar
lintangsutawika committed
5
from lm_eval import utils
6
7
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
Leo Gao's avatar
Leo Gao committed
8
9


baberabb's avatar
baberabb committed
10
def get_result(response: dict, ctxlen: int) -> Tuple[float, bool]:
    """Extract the continuation log-likelihood from an OpenAI API response.

    :param response: dict
        A single choice from the OpenAI completions API; must carry a
        "logprobs" entry with "tokens", "token_logprobs" and "top_logprobs".
    :param ctxlen: int
        Number of leading context tokens to slice away, keeping only the
        prediction portion.
    :return:
        continuation_logprobs: float
            Sum of log probabilities over the continuation tokens.
        is_greedy: bool
            True iff every continuation token is the argmax of its
            position's top_logprobs distribution.
    """
    logprob_info = response["logprobs"]
    continuation_logprobs = sum(logprob_info["token_logprobs"][ctxlen:])

    tokens = logprob_info["tokens"]
    is_greedy = True
    for idx in range(ctxlen, len(tokens)):
        candidates = logprob_info["top_logprobs"][idx]
        # Argmax over the returned candidate distribution at this position.
        best = max(candidates, key=candidates.get)
        if best != tokens[idx]:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy
Jason Phang's avatar
gpt3  
Jason Phang committed
36
37


Leo Gao's avatar
Leo Gao committed
38
def oa_completion(**kwargs):
    """Send a request to the OpenAI completions API.

    Keeps retrying with multiplicative back-off until the API responds.
    All keyword arguments are forwarded to ``openai.Completion.create``.
    """
    # Imported lazily so the harness works without the optional deps installed.
    try:
        import openai, tiktoken  # noqa: E401
    except ModuleNotFoundError:
        raise Exception(
            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
        )

    delay = 3
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except openai.error.OpenAIError:
            import traceback

            # Surface the failure for debugging, then wait and retry.
            traceback.print_exc()
            time.sleep(delay)
            delay *= 1.5


63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
def oa_chat_completion(is_async: bool = False, **kwargs):
    """Query OpenAI API for chat completion.

    Retry with back-off until they respond.

    :param is_async: bool
        When True, returns an awaitable coroutine instead of the completed
        response; the retry loop then runs inside that coroutine.
    """
    try:
        import openai, tiktoken  # noqa: E401
    except ModuleNotFoundError:
        raise Exception(
            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
        )

    backoff_time = 3

    if is_async:
        # BUG FIX: `openai.ChatCompletion.acreate` returns a coroutine, so any
        # OpenAIError is raised at await-time -- *outside* a try block in this
        # synchronous function -- meaning the retry loop below never applied to
        # async calls. Defer the retry into the returned coroutine instead, and
        # use asyncio.sleep so back-off doesn't block the event loop.
        async def _acreate_with_retry():
            import asyncio

            delay = backoff_time
            while True:
                try:
                    return await openai.ChatCompletion.acreate(**kwargs)
                except openai.error.OpenAIError:
                    import traceback

                    traceback.print_exc()
                    await asyncio.sleep(delay)
                    delay *= 1.5

        return _acreate_with_retry()

    while True:
        try:
            return openai.ChatCompletion.create(**kwargs)
        except openai.error.OpenAIError:
            import traceback

            traceback.print_exc()
            time.sleep(backoff_time)
            backoff_time *= 1.5


async def oa_chat_completion_async(**kwargs):
    """Query async OpenAI API for chat completion.

    Retry with back-off until they respond.
    """
    # Delegate to the shared helper in async mode and await the awaitable it returns.
    return await oa_chat_completion(is_async=True, **kwargs)


haileyschoelkopf's avatar
haileyschoelkopf committed
101
@register_model("openai", "openai-completions", "gooseai")
class OpenaiCompletionsLM(LM):
    """lm-eval LM backend for the OpenAI completions API.

    Tokenization is performed locally with tiktoken; requests are batched
    into chunks of ``REQ_CHUNK_SIZE`` prompts per API call.
    """

    # Number of prompts batched into a single oa_completion call.
    REQ_CHUNK_SIZE = 20

    def __init__(
        self,
        engine: str = "text-davinci-003",
        truncate: bool = False,
        batch_size: int = 1,
    ) -> None:
        """

        :param engine: str
            OpenAI API engine (e.g. davinci)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        """
        super().__init__()
        # Imported lazily so the harness works without the optional deps installed.
        try:
            import openai, tiktoken  # noqa: E401
        except ModuleNotFoundError:
            raise Exception(
                "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
    please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
            )
        self.engine = engine
        # Local tokenizer mirroring what the API uses for this engine.
        self.tokenizer = tiktoken.encoding_for_model(self.engine)
        self.vocab_size = self.tokenizer.n_vocab
        self.truncate = truncate
        self.end_of_text_token_id = self.tokenizer.eot_token

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

    @property
    def eot_token_id(self):
        # tiktoken's end-of-text id doubles as our end-of-text marker.
        return self.end_of_text_token_id

    @property
    def max_length(self) -> int:
        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
        return 2048

    @property
    def max_gen_toks(self) -> int:
        # Generation budget reserved out of the context window.
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def tok_encode(self, string: str) -> List[int]:
        """Encode text to token ids with the engine's tiktoken encoding."""
        return self.tokenizer.encode(string)

    def tok_decode(self, tokens: List[int]) -> str:
        """Decode token ids back to text."""
        return self.tokenizer.decode(tokens)

    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
        """Tokenize (context, continuation), keeping the split consistent.

        Trailing spaces on the context are moved onto the continuation so the
        concatenated text tokenizes identically to how it would at generation
        time; the continuation tokens are then taken as the suffix of the
        whole encoding past the context encoding's length.
        """
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]
        whole_enc = self.tok_encode(context + continuation)
        context_enc = self.tok_encode(context)
        continuation_enc = whole_enc[len(context_enc) :]
        return context_enc, continuation_enc

    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
        """Tokenize each (context, continuation) request and score it."""
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # end of text as context
                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
                    continuation
                )
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def _loglikelihood_tokens(
        self, requests, disable_tqdm: bool = False
    ) -> List[Tuple[float, bool]]:
        """Score pre-tokenized (cache_key, context, continuation) requests."""
        res = []

        def _collate(x):
            # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
            # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
            # we care about, and so we need some kind of backup for when it isn't
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(
            list(utils.chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)),
            disable=disable_tqdm,
        ):
            inps = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # max_length+1 because the API takes up to 2049 tokens, including the first context token
                inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
                # TODO: the logic is much simpler if we just look at the length of continuation tokens
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
                )

                inps.append(inp)
                ctxlens.append(ctxlen)

            # echo=True + max_tokens=0 scores the prompt itself without generating.
            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                echo=True,
                max_tokens=0,
                temperature=0.0,
                logprobs=10,
            )

            for resp, ctxlen, (cache_key, _, _) in zip(
                response.choices, ctxlens, chunk
            ):
                answer = get_result(resp, ctxlen)

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
        return re_ord.get_original(res)

    def greedy_until(self, requests) -> List[str]:
        """Generate greedily from each request's context until a stop sequence."""
        if not requests:
            return []
        res = []
        requests = [req.args for req in requests]

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        def sameuntil_chunks(xs, size):
            # Group consecutive requests that share identical stop sequences,
            # capped at `size` per group, so one API call can serve each group.
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil

        # todo: more intelligent batching for heterogeneous `until`
        for chunk, request_args in tqdm(
            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
        ):
            inps = []
            for context, _ in chunk:
                context_enc = self.tok_encode(context)
                # Leave room in the context window for generated tokens.
                inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                inps.append(inp)

            until = request_args.get("until", ["<|endoftext|>"])

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                max_tokens=self.max_gen_toks,
                temperature=0.0,
                logprobs=10,
                stop=until,
            )

            for resp, (context, args_) in zip(response.choices, chunk):
                s = resp["text"]

                until_ = args_.get("until", ["<|endoftext|>"])

                # Truncate at the first occurrence of any stop sequence.
                for term in until_:
                    if len(term) > 0:
                        s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial(
                    "greedy_until", (context, {"until": until_}), s
                )

                res.append(s)
        return re_ord.get_original(res)

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override greedy_until
        raise NotImplementedError()

    def loglikelihood_rolling(self, requests) -> List[float]:
        """Score whole strings via disjoint rolling windows of max_length."""
        loglikelihoods = []

        for (string,) in tqdm([req.args for req in requests]):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            string_nll = self._loglikelihood_tokens(
                rolling_token_windows,
                disable_tqdm=True,
            )

            # Sum window scores, discarding the is_greedy flags.
            loglikelihoods.append(sum(x[0] for x in string_nll))
        return loglikelihoods
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415


@register_model("openai", "openai-chat-completions", "gooseai")
class OpenaiChatCompletionsLM(OpenaiCompletionsLM):
    """Variant of OpenaiCompletionsLM that targets the chat completions endpoint."""

    def __init__(
        self, engine: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1
    ) -> None:
        super().__init__(engine, truncate, batch_size)

    def greedy_until(self, requests) -> List[str]:
        """Generate greedily via the chat endpoint until a stop sequence.

        BUG FIXES vs. previous version, which could not have worked:
        - messages carried raw token-id lists as "content"; the chat API takes
          text, so the truncated context is decoded back to a string.
        - it called the *completions* endpoint (oa_completion with `prompt=`);
          now routed through oa_chat_completion with `model=`/`messages=`.
        - it requested `logprobs=10`, which the chat endpoint does not accept.
        - it read `resp["text"]`; chat responses carry
          `choices[0]["message"]["content"]`.
        The chat API takes one conversation per call, so requests are issued
        individually rather than batched.
        """
        if not requests:
            return []
        res = []
        requests = [req.args for req in requests]

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        for context, args_ in tqdm(re_ord.get_reordered()):
            # Truncate on tokens (leaving room for generation), then decode
            # back to text for the chat API.
            context_enc = self.tok_encode(context)
            inp = self.tok_decode(
                context_enc[-(self.max_length - self.max_gen_toks) :]
            )

            until = args_.get("until", ["<|endoftext|>"])

            response = oa_chat_completion(
                model=self.engine,
                messages=[{"role": "user", "content": inp}],
                max_tokens=self.max_gen_toks,
                temperature=0.0,
                stop=until,
            )

            s = response["choices"][0]["message"]["content"]

            # Truncate at the first occurrence of any stop sequence.
            for term in until:
                if len(term) > 0:
                    s = s.split(term)[0]

            # partial caching
            self.cache_hook.add_partial(
                "greedy_until", (context, {"until": until}), s
            )

            res.append(s)
        return re_ord.get_original(res)